From b28f175b617283bd2c833eb0053560bb1d989cb9 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Wed, 10 Jun 2026 18:48:26 +0200 Subject: [PATCH 1/5] feat: transparent openai responses api integration --- README.md | 37 +++- api/openai.py | 189 +++++++++------- api/responses.py | 398 +++++++++++++++++++++++++++++++++ db.py | 175 ++++++++++++++- requests/responses.py | 492 +++++++++++++++++++++++++++++++++++++++++ router.py | 9 + test/test_responses.py | 460 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 1674 insertions(+), 86 deletions(-) create mode 100644 api/responses.py create mode 100644 requests/responses.py create mode 100644 test/test_responses.py diff --git a/README.md b/README.md index fb60988..52f7b99 100644 --- a/README.md +++ b/README.md @@ -132,6 +132,41 @@ This way the Ollama backend servers are utilized more efficient than by simply u NOMYO Router also supports OpenAI API compatible v1 backend servers. +## OpenAI Responses API + +In addition to Chat Completions, NOMYO Router exposes the OpenAI **Responses API**: + +``` +POST /v1/responses # create a response (stream or non-stream) +GET /v1/responses/{id} # retrieve a stored response +DELETE /v1/responses/{id} # delete a stored response +POST /v1/responses/{id}/cancel # cancel a background response +``` + +It works transparently across **all** backends. When the routed model lives on a native +Responses backend (external OpenAI) the request is forwarded as-is; for Ollama and llama-server the +router translates Responses ⇄ Chat Completions in both directions (request, response, and streaming +typed SSE events), so clients get a consistent `/v1/responses` surface regardless of backend. + +### Conversation state (`store` / `previous_response_id`) + +The router **owns conversation state itself** (persisted in its SQLite DB) rather than delegating to +the upstream provider, so `store` and `previous_response_id` behave identically on every backend. +On a follow-up request the router rehydrates the prior turns from its DB and expands them into the +conversation; outbound native calls always send `store=false`. Trade-off: this forgoes OpenAI's +server-side reasoning-state reuse in exchange for uniform, backend-agnostic chaining. + +### Background mode + +`background:true` (which requires `store:true`) returns immediately with `{"status":"queued"}`; the +request runs server-side and the client polls `GET /v1/responses/{id}` until the status reaches a +terminal state (`completed` / `failed` / `cancelled`). `POST /v1/responses/{id}/cancel` aborts it. + +Limitations: streaming reconnect-resume via `starting_after` is not yet implemented. In a +multi-worker/replica deployment polling works via the shared DB, but `cancel` only reaches the +running task in the worker that started it (other workers just mark the stored row cancelled). A +background task interrupted by a server restart is reconciled to `failed` on the next startup. + ## Semantic LLM Cache NOMYO Router includes an optional semantic cache that serves repeated or semantically similar LLM requests from cache — no endpoint round-trip, no token cost, response in <10 ms. @@ -172,7 +207,7 @@ Each request is keyed on `model + system_prompt` (exact) combined with a weighte ### Cached routes -`/api/chat` · `/api/generate` · `/v1/chat/completions` · `/v1/completions` +`/api/chat` · `/api/generate` · `/v1/chat/completions` · `/v1/completions` · `/v1/responses` ### Cache management diff --git a/api/openai.py b/api/openai.py index 4662f50..ab24f54 100644 --- a/api/openai.py +++ b/api/openai.py @@ -46,6 +46,110 @@ from routing import choose_endpoint, decrement_usage router = APIRouter() +async def create_chat_with_retries(oclient, send_params, endpoint, model, tracking_model): + """Call ``chat.completions.create`` with the router's resilience retries. + + Encapsulates the recovery ladder shared by the chat-completions handler and + the translated ``/v1/responses`` path: + + * ``does not support tools`` → retry without ``tools`` + * llama-server context exhaustion → sliding-window message trim, with a + second retry that also strips ``tools``/``tool_choice`` + * backend connection failure → mark (endpoint, model) unhealthy so the next + request reroutes, then re-raise + * ``image input is not supported`` → strip images and retry + + On unrecoverable failure the endpoint usage counter is decremented and the + exception is re-raised. Returns the established async generator / response. + """ + config = get_config() + try: + async_gen = await oclient.chat.completions.create(**send_params) + except Exception as e: + _e_str = str(e) + _is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str + print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True) + if "does not support tools" in _e_str: + # Model doesn't support tools — retry without them + print(f"[ochat] retry: no tools", flush=True) + try: + params_without_tools = {k: v for k, v in send_params.items() if k != "tools"} + async_gen = await oclient.chat.completions.create(**params_without_tools) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + elif _is_ctx_err: + # Backend context limit hit — apply sliding-window trim (context-shift at message level) + err_body = getattr(e, "body", {}) or {} + err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {} + n_ctx_limit = err_detail.get("n_ctx", 0) + actual_tokens = err_detail.get("n_prompt_tokens", 0) + # Fallback: parse from string if body parsing yielded nothing (SDK may not parse llama-server errors) + if not n_ctx_limit: + import re as _re + _m = _re.search(r"'n_ctx':\s*(\d+)", _e_str) + if _m: + n_ctx_limit = int(_m.group(1)) + _m = _re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str) + if _m: + actual_tokens = int(_m.group(1)) + print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True) + if not n_ctx_limit: + await decrement_usage(endpoint, tracking_model) + raise + if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(endpoint, model)] = n_ctx_limit + + msgs_to_trim = send_params.get("messages", []) + try: + cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens) + trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target) + except Exception as _helper_exc: + print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True) + await decrement_usage(endpoint, tracking_model) + raise + dropped = len(msgs_to_trim) - len(trimmed_messages) + print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True) + try: + async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages}) + print(f"[ctx-trim] retry-1 ok", flush=True) + except Exception as e2: + _e2_str = str(e2) + if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str: + # Still too large — tool definitions likely consuming too many tokens, strip them too + print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True) + params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")} + try: + async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages}) + print(f"[ctx-trim] retry-2 ok", flush=True) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + else: + await decrement_usage(endpoint, tracking_model) + raise + elif _is_backend_connection_error(e): + # Upstream connection failed (e.g. llama-server in router mode + # whose delegated worker died). Mark (endpoint, model) so the + # next request reroutes; the client will retry this one. + print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True) + await _mark_backend_unhealthy(endpoint, model, _e_str) + await decrement_usage(endpoint, tracking_model) + raise + elif "image input is not supported" in _e_str: + # Model doesn't support images — strip and retry + print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages") + try: + async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))}) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + else: + await decrement_usage(endpoint, tracking_model) + raise + return async_gen + + @router.post("/v1/embeddings") async def openai_embedding_proxy(request: Request): """ @@ -260,90 +364,7 @@ async def openai_chat_completions_proxy(request: Request): _dropped = len(_pre_msgs) - len(_pre_trimmed) print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True) send_params = {**send_params, "messages": _pre_trimmed} - try: - async_gen = await oclient.chat.completions.create(**send_params) - except Exception as e: - _e_str = str(e) - _is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str - print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True) - if "does not support tools" in _e_str: - # Model doesn't support tools — retry without them - print(f"[ochat] retry: no tools", flush=True) - try: - params_without_tools = {k: v for k, v in send_params.items() if k != "tools"} - async_gen = await oclient.chat.completions.create(**params_without_tools) - except Exception: - await decrement_usage(endpoint, tracking_model) - raise - elif _is_ctx_err: - # Backend context limit hit — apply sliding-window trim (context-shift at message level) - err_body = getattr(e, "body", {}) or {} - err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {} - n_ctx_limit = err_detail.get("n_ctx", 0) - actual_tokens = err_detail.get("n_prompt_tokens", 0) - # Fallback: parse from string if body parsing yielded nothing (SDK may not parse llama-server errors) - if not n_ctx_limit: - import re as _re - _m = _re.search(r"'n_ctx':\s*(\d+)", _e_str) - if _m: - n_ctx_limit = int(_m.group(1)) - _m = _re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str) - if _m: - actual_tokens = int(_m.group(1)) - print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True) - if not n_ctx_limit: - await decrement_usage(endpoint, tracking_model) - raise - if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT: - _endpoint_nctx[(endpoint, model)] = n_ctx_limit - - msgs_to_trim = send_params.get("messages", []) - try: - cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens) - trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target) - except Exception as _helper_exc: - print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True) - await decrement_usage(endpoint, tracking_model) - raise - dropped = len(msgs_to_trim) - len(trimmed_messages) - print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True) - try: - async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages}) - print(f"[ctx-trim] retry-1 ok", flush=True) - except Exception as e2: - _e2_str = str(e2) - if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str: - # Still too large — tool definitions likely consuming too many tokens, strip them too - print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True) - params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")} - try: - async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages}) - print(f"[ctx-trim] retry-2 ok", flush=True) - except Exception: - await decrement_usage(endpoint, tracking_model) - raise - else: - await decrement_usage(endpoint, tracking_model) - raise - elif _is_backend_connection_error(e): - # Upstream connection failed (e.g. llama-server in router mode - # whose delegated worker died). Mark (endpoint, model) so the - # next request reroutes; the client will retry this one. - print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True) - await _mark_backend_unhealthy(endpoint, model, _e_str) - await decrement_usage(endpoint, tracking_model) - raise - elif "image input is not supported" in _e_str: - # Model doesn't support images — strip and retry - print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages") - try: - async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))}) - except Exception: - await decrement_usage(endpoint, tracking_model) - raise - else: - await decrement_usage(endpoint, tracking_model) - raise + async_gen = await create_chat_with_retries(oclient, send_params, endpoint, model, tracking_model) # 4. Async generator — only streams the already-established async_gen async def stream_ochat_response(): diff --git a/api/responses.py b/api/responses.py new file mode 100644 index 0000000..0a803d3 --- /dev/null +++ b/api/responses.py @@ -0,0 +1,398 @@ +"""OpenAI **Responses API** routes (``/v1/responses`` and its retrieve / delete / +cancel companions). + +The router speaks Chat Completions to its backends, so this layer: + + * **native** (external OpenAI): forwards via ``oclient.responses.create`` and + streams the SDK's typed events straight back, rewriting the response ``id`` to + a router-owned ``resp_`` id so chaining stays router-managed. + * **translated** (Ollama / llama-server): converts the request to chat, reuses + the resilient ``create_chat_with_retries`` ladder, and re-emits the result as + Responses typed SSE events (``requests/responses.py``). + +State (``store`` / ``previous_response_id``) and background-task status live in the +router's SQLite DB (``db.py``); the router mints and owns every response id. +""" +import asyncio +import secrets +import time + +import orjson +from fastapi import APIRouter, HTTPException, Request +from starlette.responses import JSONResponse, StreamingResponse + +from cache import get_llm_cache +from config import get_config +from db import get_db +from fingerprint import _conversation_fingerprint +from state import token_queue, default_headers +from backends.normalize import is_ext_openai_endpoint +from backends.sessions import _make_openai_client +from routing import choose_endpoint, decrement_usage +from api.openai import create_chat_with_retries +from requests.responses import ( + ChatToResponsesStream, + build_response_object, + chat_message_to_output_items, + messages_to_responses_input, + responses_input_to_messages, + responses_object_to_sse, + tools_responses_to_chat, + usage_chat_to_responses, +) + +router = APIRouter() + +# In-memory handles for background tasks so /cancel can reach a running task in +# this worker. Cross-worker cancel falls back to marking the DB row cancelled. +_background_tasks: dict[str, asyncio.Task] = {} + + +# --------------------------------------------------------------------------- +# small helpers +# --------------------------------------------------------------------------- +def _usage_tokens(usage): + """Return ``(prompt, completion)`` tokens from a chat- or responses-shaped usage.""" + if not usage: + return 0, 0 + if "input_tokens" in usage: + return usage.get("input_tokens", 0) or 0, usage.get("output_tokens", 0) or 0 + return usage.get("prompt_tokens", 0) or 0, usage.get("completion_tokens", 0) or 0 + + +def _text_format_to_response_format(text): + """Map Responses ``text.format`` → Chat Completions ``response_format`` (best effort).""" + if not isinstance(text, dict): + return None + fmt = text.get("format") + if not isinstance(fmt, dict): + return None + ftype = fmt.get("type") + if ftype == "json_object": + return {"type": "json_object"} + if ftype == "json_schema": + return {"type": "json_schema", "json_schema": { + k: fmt[k] for k in ("name", "schema", "strict", "description") if k in fmt + }} + return None + + +def _native_usage_from_response(data): + return data.get("usage") + + +async def _resolve_history_messages(previous_response_id): + """Rebuild prior-turn chat messages from the stored response chain.""" + if not previous_response_id: + return [] + db = get_db() + chain = await db.get_response_chain(previous_response_id) + messages = [] + for turn in chain: + # Each turn stored the chat messages that produced it + its output items. + for m in turn.get("input_messages") or []: + messages.append(m) + for item in turn.get("output_items") or []: + if item.get("type") == "message": + text = "".join( + p.get("text", "") for p in item.get("content") or [] + if p.get("type") == "output_text" + ) + if text: + messages.append({"role": "assistant", "content": text}) + elif item.get("type") == "function_call": + messages.append({ + "role": "assistant", "content": None, + "tool_calls": [{"id": item.get("call_id"), "type": "function", + "function": {"name": item.get("name"), + "arguments": item.get("arguments", "")}}], + }) + return messages + + +class _NativeStream: + """Re-emit an SDK Responses event stream, rewriting the response id and + capturing the final output/usage for storage.""" + + def __init__(self, response_id): + self.response_id = response_id + self.output_items = [] + self.usage = None + + async def events(self, sdk_gen): + async for event in sdk_gen: + data = event.model_dump() if hasattr(event, "model_dump") else event + etype = data.get("type", "") + resp = data.get("response") + if isinstance(resp, dict) and resp.get("id"): + resp["id"] = self.response_id + if etype in ("response.completed", "response.incomplete", "response.failed") \ + and isinstance(resp, dict): + self.output_items = resp.get("output", []) or [] + self.usage = resp.get("usage") + yield f"event: {etype}\ndata: {orjson.dumps(data).decode('utf-8')}\n\n".encode("utf-8") + + +# --------------------------------------------------------------------------- +# backend execution (non-streaming, used by background + non-stream sync) +# --------------------------------------------------------------------------- +async def _run_to_completion(*, native, oclient, endpoint, model, tracking_model, + send_params, native_params): + """Drive the backend to completion (no client streaming). + + Returns ``(output_items, usage)`` where usage is responses-shaped. Caller is + responsible for ``decrement_usage`` (translated failures self-decrement inside + ``create_chat_with_retries``).""" + if native: + resp_obj = await oclient.responses.create(stream=False, **native_params) + data = resp_obj.model_dump() + return data.get("output", []) or [], data.get("usage") + async_gen = await create_chat_with_retries(oclient, {**send_params, "stream": False}, + endpoint, model, tracking_model) + message = async_gen.choices[0].message.model_dump() if async_gen.choices else {} + output_items = chat_message_to_output_items(message) + usage = usage_chat_to_responses( + async_gen.usage.model_dump() if async_gen.usage is not None else None + ) + return output_items, usage + + +# --------------------------------------------------------------------------- +# POST /v1/responses +# --------------------------------------------------------------------------- +@router.post("/v1/responses") +async def openai_responses_proxy(request: Request): + config = get_config() + try: + payload = orjson.loads((await request.body()).decode("utf-8")) + except orjson.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + model = payload.get("model") + input_data = payload.get("input") + instructions = payload.get("instructions") + stream = bool(payload.get("stream")) + store = payload.get("store", True) + background = bool(payload.get("background")) + previous_response_id = payload.get("previous_response_id") + tools = payload.get("tools") + metadata = payload.get("metadata") or {} + _cache_enabled = payload.get("nomyo", {}).get("cache", False) + + if not model: + raise HTTPException(status_code=400, detail="Missing required field 'model'") + if input_data is None: + raise HTTPException(status_code=400, detail="Missing required field 'input'") + if background and not store: + raise HTTPException(status_code=400, detail="background mode requires store=true") + + if ":latest" in model: + model = model.split(":latest")[0] + + # Resolve conversation: prior turns (from store) + this turn's input. + history = await _resolve_history_messages(previous_response_id) + messages = history + responses_input_to_messages(input_data, instructions) + + response_id = f"resp_{secrets.token_hex(24)}" + created_at = int(time.time()) + + # Cache lookup (foreground only) — before endpoint selection. + _cache = get_llm_cache() + if _cache is not None and _cache_enabled and not background: + cached = await _cache.get_chat("openai_responses", model, messages) + if cached is not None: + resp_obj = orjson.loads(cached) + resp_obj["id"] = response_id + if stream: + async def _served_cached(): + yield responses_object_to_sse(resp_obj) + return StreamingResponse(_served_cached(), media_type="text/event-stream") + return JSONResponse(content=resp_obj) + + # Endpoint selection (reserves a slot — must be released exactly once). + _affinity_key = _conversation_fingerprint(model, messages, None) + endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) + oclient = _make_openai_client(endpoint, default_headers=default_headers, + api_key=config.api_keys.get(endpoint, "no-key")) + native = is_ext_openai_endpoint(endpoint) + + # Build backend params for both shapes. + send_params = {"messages": messages, "model": model} + _opt = { + "temperature": payload.get("temperature"), + "top_p": payload.get("top_p"), + "max_tokens": payload.get("max_output_tokens"), + "tools": tools_responses_to_chat(tools), + "tool_choice": payload.get("tool_choice"), + "response_format": _text_format_to_response_format(payload.get("text")), + } + send_params.update({k: v for k, v in _opt.items() if v is not None}) + + native_instructions, native_input = messages_to_responses_input(messages) + native_params = {"model": model, "input": native_input, "store": False} + _nopt = { + "instructions": native_instructions, + "temperature": payload.get("temperature"), + "top_p": payload.get("top_p"), + "max_output_tokens": payload.get("max_output_tokens"), + "tools": tools, + "tool_choice": payload.get("tool_choice"), + "text": payload.get("text"), + "reasoning": payload.get("reasoning"), + } + native_params.update({k: v for k, v in _nopt.items() if v is not None}) + + async def _persist(status, output_items=None, usage=None, error=None, insert=False): + if not store: + return + db = get_db() + if insert: + await db.store_response( + response_id, previous_response_id=previous_response_id, model=model, + status=status, created_at=created_at, input_messages=messages, + output_items=output_items, usage=usage, instructions=instructions, error=error) + else: + await db.update_response_status(response_id, status, output_items=output_items, + usage=usage, error=error) + + async def _track(usage): + prompt_tok, comp_tok = _usage_tokens(usage) + if prompt_tok or comp_tok: + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) + + async def _cache_store(output_items, usage): + if _cache is None or not _cache_enabled or not output_items: + return + obj = build_response_object(response_id=response_id, model=model, + output_items=output_items, usage=usage, + created_at=created_at, + previous_response_id=previous_response_id, + instructions=instructions, metadata=metadata) + try: + await _cache.set_chat("openai_responses", model, messages, orjson.dumps(obj)) + except Exception as _ce: + print(f"[cache] set_chat (openai_responses) failed: {_ce}") + + # ---- background: run detached, return queued immediately -------------- + if background: + await _persist("queued", insert=True) + + async def _bg_run(): + try: + await get_db().update_response_status(response_id, "in_progress") + output_items, usage = await _run_to_completion( + native=native, oclient=oclient, endpoint=endpoint, model=model, + tracking_model=tracking_model, send_params=send_params, + native_params=native_params) + await _track(usage) + await _persist("completed", output_items=output_items, usage=usage) + await _cache_store(output_items, usage) + except asyncio.CancelledError: + await get_db().update_response_status(response_id, "cancelled") + raise + except Exception as e: + await get_db().update_response_status( + response_id, "failed", + error={"message": str(e)[:500], "type": type(e).__name__}) + finally: + await decrement_usage(endpoint, tracking_model) + _background_tasks.pop(response_id, None) + + task = asyncio.create_task(_bg_run()) + _background_tasks[response_id] = task + queued = build_response_object(response_id=response_id, model=model, output_items=[], + status="queued", created_at=created_at, + previous_response_id=previous_response_id, + instructions=instructions, metadata=metadata) + return JSONResponse(content=queued, status_code=200) + + # ---- streaming sync ---------------------------------------------------- + if stream: + if native: + source = await oclient.responses.create(stream=True, **native_params) + translator = _NativeStream(response_id) + else: + source = await create_chat_with_retries( + oclient, {**send_params, "stream": True, + "stream_options": {"include_usage": True}}, + endpoint, model, tracking_model) + translator = ChatToResponsesStream( + response_id, model, created_at=created_at, + previous_response_id=previous_response_id, instructions=instructions, + metadata=metadata) + + async def _stream(): + await _persist("in_progress", insert=True) + try: + async for sse in translator.events(source): + yield sse + await _track(translator.usage) + await _persist("completed", output_items=translator.output_items, + usage=translator.usage) + await _cache_store(translator.output_items, translator.usage) + finally: + await decrement_usage(endpoint, tracking_model) + + return StreamingResponse(_stream(), media_type="text/event-stream") + + # ---- non-streaming sync ------------------------------------------------ + try: + output_items, usage = await _run_to_completion( + native=native, oclient=oclient, endpoint=endpoint, model=model, + tracking_model=tracking_model, send_params=send_params, + native_params=native_params) + await _track(usage) + await _persist("completed", output_items=output_items, usage=usage, insert=True) + await _cache_store(output_items, usage) + finally: + await decrement_usage(endpoint, tracking_model) + + resp_obj = build_response_object( + response_id=response_id, model=model, output_items=output_items, usage=usage, + created_at=created_at, previous_response_id=previous_response_id, + instructions=instructions, metadata=metadata) + return JSONResponse(content=resp_obj) + + +# --------------------------------------------------------------------------- +# GET / DELETE / cancel +# --------------------------------------------------------------------------- +def _stored_to_response_object(row): + return build_response_object( + response_id=row["response_id"], model=row.get("model"), + output_items=row.get("output_items") or [], usage=row.get("usage"), + status=row.get("status") or "completed", created_at=row.get("created_at"), + previous_response_id=row.get("previous_response_id"), + instructions=row.get("instructions"), error=row.get("error")) + + +@router.get("/v1/responses/{response_id}") +async def get_response(response_id: str): + row = await get_db().get_response(response_id) + if row is None: + raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found") + return JSONResponse(content=_stored_to_response_object(row)) + + +@router.delete("/v1/responses/{response_id}") +async def delete_response(response_id: str): + deleted = await get_db().delete_response(response_id) + if not deleted: + raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found") + return JSONResponse(content={"id": response_id, "object": "response.deleted", "deleted": True}) + + +@router.post("/v1/responses/{response_id}/cancel") +async def cancel_response(response_id: str): + row = await get_db().get_response(response_id) + if row is None: + raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found") + # Cancel the running task if it lives in this worker; otherwise just mark the + # DB row so a polling client sees a terminal state (cross-worker limitation). + task = _background_tasks.get(response_id) + if task is not None and not task.done(): + task.cancel() + elif row.get("status") in ("queued", "in_progress"): + await get_db().update_response_status(response_id, "cancelled") + row = await get_db().get_response(response_id) + return JSONResponse(content=_stored_to_response_object(row)) diff --git a/db.py b/db.py index 3621144..03a5393 100644 --- a/db.py +++ b/db.py @@ -1,4 +1,4 @@ -import aiosqlite, asyncio +import aiosqlite, asyncio, orjson from typing import Optional from pathlib import Path from datetime import datetime, timezone @@ -75,6 +75,24 @@ class TokenDatabase: ''') await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_timestamp ON token_time_series(timestamp)') await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_model_ts ON token_time_series(model, timestamp)') + # Responses API state — the router owns conversation state for the + # /v1/responses family (store / previous_response_id) and tracks + # background-task status here so polling survives across workers. + await db.execute(''' + CREATE TABLE IF NOT EXISTS stored_responses ( + response_id TEXT PRIMARY KEY, + previous_response_id TEXT, + model TEXT, + status TEXT, + created_at INTEGER, + input_messages TEXT, + output_items TEXT, + usage TEXT, + instructions TEXT, + error TEXT + ) + ''') + await db.execute('CREATE INDEX IF NOT EXISTS idx_stored_responses_prev ON stored_responses(previous_response_id)') await db.commit() async def update_token_counts(self, endpoint: str, model: str, input_tokens: int, output_tokens: int): @@ -319,3 +337,158 @@ class TokenDatabase: await db.commit() return aggregated_count + + # ----------------------------------------------------------------- + # Responses API state (store / previous_response_id / background) + # ----------------------------------------------------------------- + @staticmethod + def _row_to_response(row) -> dict: + """Map a stored_responses row to a plain dict, decoding JSON columns.""" + def _loads(val): + if val is None: + return None + try: + return orjson.loads(val) + except (orjson.JSONDecodeError, TypeError): + return None + return { + 'response_id': row[0], + 'previous_response_id': row[1], + 'model': row[2], + 'status': row[3], + 'created_at': row[4], + 'input_messages': _loads(row[5]), + 'output_items': _loads(row[6]), + 'usage': _loads(row[7]), + 'instructions': row[8], + 'error': _loads(row[9]), + } + + async def store_response( + self, + response_id: str, + *, + previous_response_id: Optional[str], + model: str, + status: str, + created_at: int, + input_messages: list, + output_items: Optional[list] = None, + usage: Optional[dict] = None, + instructions: Optional[str] = None, + error: Optional[dict] = None, + ): + """Insert or replace a stored Responses-API response row.""" + db = await self._get_connection() + async with self._operation_lock: + await db.execute(''' + INSERT INTO stored_responses + (response_id, previous_response_id, model, status, created_at, + input_messages, output_items, usage, instructions, error) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT (response_id) DO UPDATE SET + previous_response_id = excluded.previous_response_id, + model = excluded.model, + status = excluded.status, + created_at = excluded.created_at, + input_messages = excluded.input_messages, + output_items = excluded.output_items, + usage = excluded.usage, + instructions = excluded.instructions, + error = excluded.error + ''', ( + response_id, previous_response_id, model, status, created_at, + orjson.dumps(input_messages).decode("utf-8"), + orjson.dumps(output_items).decode("utf-8") if output_items is not None else None, + orjson.dumps(usage).decode("utf-8") if usage is not None else None, + instructions, + orjson.dumps(error).decode("utf-8") if error is not None else None, + )) + await db.commit() + + async def update_response_status( + self, + response_id: str, + status: str, + *, + output_items: Optional[list] = None, + usage: Optional[dict] = None, + error: Optional[dict] = None, + ): + """Update the status (and optionally output/usage/error) of a stored response.""" + db = await self._get_connection() + async with self._operation_lock: + await db.execute(''' + UPDATE stored_responses + SET status = ?, + output_items = COALESCE(?, output_items), + usage = COALESCE(?, usage), + error = COALESCE(?, error) + WHERE response_id = ? + ''', ( + status, + orjson.dumps(output_items).decode("utf-8") if output_items is not None else None, + orjson.dumps(usage).decode("utf-8") if usage is not None else None, + orjson.dumps(error).decode("utf-8") if error is not None else None, + response_id, + )) + await db.commit() + + async def get_response(self, response_id: str) -> Optional[dict]: + """Return a stored response as a dict, or None if not found.""" + db = await self._get_connection() + async with self._operation_lock: + async with db.execute(''' + SELECT response_id, previous_response_id, model, status, created_at, + input_messages, output_items, usage, instructions, error + FROM stored_responses WHERE response_id = ? + ''', (response_id,)) as cursor: + row = await cursor.fetchone() + return self._row_to_response(row) if row is not None else None + + async def delete_response(self, response_id: str) -> bool: + """Delete a stored response. Returns True if a row was removed.""" + db = await self._get_connection() + async with self._operation_lock: + cursor = await db.execute( + 'DELETE FROM stored_responses WHERE response_id = ?', (response_id,) + ) + await db.commit() + return cursor.rowcount > 0 + + async def get_response_chain(self, response_id: str, max_turns: int = 50) -> list: + """Walk previous_response_id back to the root, returned oldest-first. + + Bounded to ``max_turns`` so a pathological chain cannot stall a request. + Missing links terminate the walk gracefully. + """ + chain: list = [] + seen: set = set() + current = response_id + while current and current not in seen and len(chain) < max_turns: + seen.add(current) + resp = await self.get_response(current) + if resp is None: + break + chain.append(resp) + current = resp.get('previous_response_id') + chain.reverse() + return chain + + async def fail_orphaned_responses(self) -> int: + """Mark non-terminal responses as failed (called on startup). + + A background task lives in a worker's event loop; a process restart loses + it while the DB row stays ``queued``/``in_progress`` forever. Reconcile + those to ``failed`` so polling clients get a terminal state. + """ + db = await self._get_connection() + async with self._operation_lock: + cursor = await db.execute(''' + UPDATE stored_responses + SET status = 'failed', + error = ? + WHERE status IN ('queued', 'in_progress') + ''', (orjson.dumps({"message": "Response interrupted by server restart", "type": "server_error"}).decode("utf-8"),)) + await db.commit() + return cursor.rowcount diff --git a/requests/responses.py b/requests/responses.py new file mode 100644 index 0000000..907dc42 --- /dev/null +++ b/requests/responses.py @@ -0,0 +1,492 @@ +"""Translation between the OpenAI **Responses API** and **Chat Completions**. + +The router speaks Chat Completions to every backend (Ollama, llama-server, +external OpenAI). To expose ``/v1/responses`` transparently on top of that, this +module converts in both directions: + + * request: Responses ``input`` / ``instructions`` / ``tools`` → chat ``messages`` / ``tools`` + * response: chat ``choices[0].message`` → Responses ``output`` items + * stream: chat completion deltas → Responses typed SSE events + +Pure functions / a stream-translator class — no I/O, mirroring the style of +``requests/messages.py``. The native passthrough path (external OpenAI) does not +use this module; it forwards the SDK's Responses objects directly. +""" +import secrets +import time + +import orjson + +from requests.messages import _accumulate_openai_tc_delta + + +# --------------------------------------------------------------------------- +# Request direction: Responses → Chat Completions +# --------------------------------------------------------------------------- +def _responses_content_to_chat(content): + """Convert a Responses message ``content`` into Chat Completions content. + + Collapses a single text part to a plain string (what most backends expect); + keeps a multimodal list otherwise. + """ + if content is None or isinstance(content, str): + return content + if not isinstance(content, list): + return str(content) + parts = [] + for p in content: + if not isinstance(p, dict): + parts.append({"type": "text", "text": str(p)}) + continue + ptype = p.get("type") + if ptype in ("input_text", "output_text", "text"): + parts.append({"type": "text", "text": p.get("text", "")}) + elif ptype in ("input_image", "image_url"): + url = p.get("image_url") + if isinstance(url, dict): + url = url.get("url") + if url: + parts.append({"type": "image_url", "image_url": {"url": url}}) + # input_file / refusal / reasoning parts have no chat equivalent → skip + if len(parts) == 1 and parts[0].get("type") == "text": + return parts[0]["text"] + return parts + + +def _input_item_to_message(item): + """Convert a single Responses ``input`` item to a chat message (or None).""" + if isinstance(item, str): + return {"role": "user", "content": item} + if not isinstance(item, dict): + return None + + itype = item.get("type") + + if itype == "function_call": + return { + "role": "assistant", + "content": None, + "tool_calls": [{ + "id": item.get("call_id") or item.get("id"), + "type": "function", + "function": { + "name": item.get("name"), + "arguments": item.get("arguments", ""), + }, + }], + } + + if itype == "function_call_output": + output = item.get("output", "") + if not isinstance(output, str): + output = orjson.dumps(output).decode("utf-8") + return { + "role": "tool", + "tool_call_id": item.get("call_id") or item.get("id"), + "content": output, + } + + if itype in ("reasoning",): + # No Chat Completions equivalent — drop. + return None + + # "message" item or a bare {role, content} chat-style item + role = item.get("role") + if role is None: + return None + return {"role": role, "content": _responses_content_to_chat(item.get("content"))} + + +def responses_input_to_messages(input_data, instructions=None): + """Build a Chat Completions ``messages`` list from Responses ``input``. + + ``instructions`` becomes a leading system message; a string ``input`` becomes + a single user message; a list ``input`` is mapped item-by-item. + """ + messages = [] + if instructions: + messages.append({"role": "system", "content": instructions}) + if input_data is None: + return messages + if isinstance(input_data, str): + messages.append({"role": "user", "content": input_data}) + return messages + if isinstance(input_data, list): + for item in input_data: + msg = _input_item_to_message(item) + if msg is not None: + messages.append(msg) + return messages + + +def _chat_content_to_responses_parts(content, assistant=False): + """Convert chat message content → Responses content parts.""" + text_type = "output_text" if assistant else "input_text" + if content is None: + return [] + if isinstance(content, str): + return [{"type": text_type, "text": content}] + parts = [] + for p in content if isinstance(content, list) else []: + if not isinstance(p, dict): + parts.append({"type": text_type, "text": str(p)}) + elif p.get("type") == "text": + parts.append({"type": text_type, "text": p.get("text", "")}) + elif p.get("type") == "image_url": + url = (p.get("image_url") or {}).get("url") + if url: + parts.append({"type": "input_image", "image_url": url}) + return parts + + +def messages_to_responses_input(messages): + """Convert chat messages → ``(instructions, Responses input items)``. + + Used for the native passthrough path: history that the router has resolved in + chat-message space is re-expressed as Responses ``input``. Leading/standalone + system messages are merged into ``instructions``. + """ + instructions_parts = [] + items = [] + for m in messages: + role = m.get("role") + if role == "system": + c = m.get("content") + instructions_parts.append(c if isinstance(c, str) else orjson.dumps(c).decode("utf-8")) + continue + if role == "tool": + out = m.get("content") + if not isinstance(out, str): + out = orjson.dumps(out).decode("utf-8") + items.append({"type": "function_call_output", + "call_id": m.get("tool_call_id"), "output": out}) + continue + if role == "assistant" and m.get("tool_calls"): + for tc in m["tool_calls"]: + fn = tc.get("function", {}) + items.append({"type": "function_call", "call_id": tc.get("id"), + "name": fn.get("name"), "arguments": fn.get("arguments", "")}) + if m.get("content"): + items.append({"role": "assistant", + "content": _chat_content_to_responses_parts(m["content"], assistant=True)}) + continue + items.append({"role": role, + "content": _chat_content_to_responses_parts(m.get("content"), + assistant=(role == "assistant"))}) + instructions = "\n\n".join(p for p in instructions_parts if p) or None + return instructions, items + + +def responses_object_to_sse(resp): + """Render a *finished* Responses object as a valid SSE event stream. + + Used to serve cache/store hits to streaming clients without a backend call. + """ + seq = [-1] + + def ev(etype, payload): + seq[0] += 1 + body = {"type": etype, "sequence_number": seq[0], **payload} + return f"event: {etype}\ndata: {orjson.dumps(body).decode('utf-8')}\n\n".encode("utf-8") + + parts_out = [] + in_progress = {**resp, "status": "in_progress", "output": [], "output_text": ""} + parts_out.append(ev("response.created", {"response": in_progress})) + parts_out.append(ev("response.in_progress", {"response": in_progress})) + for oi, item in enumerate(resp.get("output", [])): + parts_out.append(ev("response.output_item.added", + {"output_index": oi, "item": {**item, "status": "in_progress"}})) + if item.get("type") == "message": + for ci, part in enumerate(item.get("content", [])): + if part.get("type") == "output_text": + iid = item.get("id") + parts_out.append(ev("response.content_part.added", { + "item_id": iid, "output_index": oi, "content_index": ci, + "part": {"type": "output_text", "text": "", "annotations": []}})) + parts_out.append(ev("response.output_text.delta", { + "item_id": iid, "output_index": oi, "content_index": ci, + "delta": part.get("text", "")})) + parts_out.append(ev("response.output_text.done", { + "item_id": iid, "output_index": oi, "content_index": ci, + "text": part.get("text", "")})) + parts_out.append(ev("response.content_part.done", { + "item_id": iid, "output_index": oi, "content_index": ci, "part": part})) + parts_out.append(ev("response.output_item.done", {"output_index": oi, "item": item})) + parts_out.append(ev("response.completed", {"response": resp})) + return b"".join(parts_out) + + +def tools_responses_to_chat(tools): + """Map Responses tool definitions (flattened) → Chat Completions (nested).""" + if not tools: + return None + out = [] + for t in tools: + if isinstance(t, dict) and t.get("type") == "function" and "function" not in t: + fn = {k: t[k] for k in ("name", "description", "parameters", "strict") if k in t} + out.append({"type": "function", "function": fn}) + else: + out.append(t) + return out + + +# --------------------------------------------------------------------------- +# Response direction: Chat Completions → Responses +# --------------------------------------------------------------------------- +def _new_id(prefix): + return f"{prefix}_{secrets.token_hex(16)}" + + +def chat_message_to_output_items(message): + """Convert an assistant chat message (dict) into Responses output items.""" + items = [] + content = message.get("content") + if content: + items.append({ + "type": "message", + "id": _new_id("msg"), + "status": "completed", + "role": "assistant", + "content": [{"type": "output_text", "text": content, "annotations": []}], + }) + for tc in message.get("tool_calls") or []: + fn = tc.get("function", {}) + items.append({ + "type": "function_call", + "id": _new_id("fc"), + "call_id": tc.get("id"), + "name": fn.get("name"), + "arguments": fn.get("arguments", ""), + "status": "completed", + }) + return items + + +def usage_chat_to_responses(usage): + """Map chat usage ``{prompt_tokens, completion_tokens}`` → Responses usage.""" + if not usage: + return None + prompt = usage.get("prompt_tokens") or 0 + completion = usage.get("completion_tokens") or 0 + return { + "input_tokens": prompt, + "output_tokens": completion, + "total_tokens": usage.get("total_tokens") or (prompt + completion), + } + + +def output_items_to_text(output_items): + """Concatenate the ``output_text`` parts of all message items.""" + chunks = [] + for item in output_items or []: + if item.get("type") != "message": + continue + for part in item.get("content") or []: + if part.get("type") == "output_text": + chunks.append(part.get("text", "")) + return "".join(chunks) + + +def build_response_object( + *, + response_id, + model, + output_items=None, + usage=None, + status="completed", + created_at=None, + previous_response_id=None, + instructions=None, + error=None, + metadata=None, +): + """Assemble a full ``object:"response"`` body for a non-streaming reply.""" + output_items = output_items or [] + return { + "id": response_id, + "object": "response", + "created_at": created_at or int(time.time()), + "status": status, + "model": model, + "output": output_items, + "output_text": output_items_to_text(output_items), + "instructions": instructions, + "previous_response_id": previous_response_id, + "usage": usage_chat_to_responses(usage) if usage and "input_tokens" not in usage else usage, + "error": error, + "metadata": metadata or {}, + } + + +# --------------------------------------------------------------------------- +# Streaming direction: Chat Completions deltas → Responses typed SSE events +# --------------------------------------------------------------------------- +class ChatToResponsesStream: + """Translate a Chat Completions streaming generator into Responses events. + + Usage:: + + translator = ChatToResponsesStream(response_id, model, created_at) + async for sse_bytes in translator.events(chat_async_gen): + yield sse_bytes + # translator.output_items / translator.usage now populated for storage + + Emits the ordered event family + ``response.created`` → ``response.in_progress`` → + (``response.output_item.added`` → ``response.content_part.added`` → + ``response.output_text.delta``* → ``response.output_text.done`` → + ``response.content_part.done`` → ``response.output_item.done``) and/or + function-call item events → ``response.completed`` (carrying usage). + """ + + def __init__(self, response_id, model, created_at=None, + previous_response_id=None, instructions=None, metadata=None): + self.response_id = response_id + self.model = model + self.created_at = created_at or int(time.time()) + self.previous_response_id = previous_response_id + self.instructions = instructions + self.metadata = metadata or {} + self.seq = -1 + self.output_items = [] + self.usage = None + + def _snapshot(self, status, output=None): + return build_response_object( + response_id=self.response_id, + model=self.model, + output_items=output if output is not None else [], + usage=self.usage, + status=status, + created_at=self.created_at, + previous_response_id=self.previous_response_id, + instructions=self.instructions, + metadata=self.metadata, + ) + + def _event(self, etype, payload): + self.seq += 1 + body = {"type": etype, "sequence_number": self.seq, **payload} + return f"event: {etype}\ndata: {orjson.dumps(body).decode('utf-8')}\n\n".encode("utf-8") + + async def events(self, async_gen): + yield self._event("response.created", {"response": self._snapshot("in_progress")}) + yield self._event("response.in_progress", {"response": self._snapshot("in_progress")}) + + next_oi = 0 + # text message state + msg_item_id = None + msg_oi = None + text_parts = [] + # function-call state, keyed by chat tool_call index + tc_state = {} # idx -> {oi, item_id, call_id, name, args} + + async for chunk in async_gen: + usage = getattr(chunk, "usage", None) + if usage is not None: + self.usage = { + "prompt_tokens": usage.prompt_tokens or 0, + "completion_tokens": usage.completion_tokens or 0, + } + choices = getattr(chunk, "choices", None) + if not choices: + continue + delta = choices[0].delta + + content_piece = getattr(delta, "content", None) + if content_piece: + if msg_item_id is None: + msg_item_id = _new_id("msg") + msg_oi = next_oi + next_oi += 1 + item = { + "id": msg_item_id, "type": "message", "status": "in_progress", + "role": "assistant", "content": [], + } + yield self._event("response.output_item.added", + {"output_index": msg_oi, "item": item}) + yield self._event("response.content_part.added", { + "item_id": msg_item_id, "output_index": msg_oi, "content_index": 0, + "part": {"type": "output_text", "text": "", "annotations": []}, + }) + text_parts.append(content_piece) + yield self._event("response.output_text.delta", { + "item_id": msg_item_id, "output_index": msg_oi, "content_index": 0, + "delta": content_piece, + }) + + for tc in getattr(delta, "tool_calls", None) or []: + idx = tc.index + fn = getattr(tc, "function", None) + if idx not in tc_state: + item_id = _new_id("fc") + state = { + "oi": next_oi, "item_id": item_id, + "call_id": getattr(tc, "id", None) or _new_id("call"), + "name": (fn.name if fn else None), "args": "", + } + next_oi += 1 + tc_state[idx] = state + yield self._event("response.output_item.added", { + "output_index": state["oi"], + "item": { + "id": item_id, "type": "function_call", "status": "in_progress", + "call_id": state["call_id"], "name": state["name"], "arguments": "", + }, + }) + else: + state = tc_state[idx] + if getattr(tc, "id", None): + state["call_id"] = tc.id + if fn and fn.name: + state["name"] = fn.name + if fn and fn.arguments: + state["args"] += fn.arguments + yield self._event("response.function_call_arguments.delta", { + "item_id": state["item_id"], "output_index": state["oi"], + "delta": fn.arguments, + }) + + # finalize message item + if msg_item_id is not None: + full_text = "".join(text_parts) + yield self._event("response.output_text.done", { + "item_id": msg_item_id, "output_index": msg_oi, "content_index": 0, + "text": full_text, + }) + done_part = {"type": "output_text", "text": full_text, "annotations": []} + yield self._event("response.content_part.done", { + "item_id": msg_item_id, "output_index": msg_oi, "content_index": 0, + "part": done_part, + }) + msg_item = { + "id": msg_item_id, "type": "message", "status": "completed", + "role": "assistant", "content": [done_part], + } + yield self._event("response.output_item.done", + {"output_index": msg_oi, "item": msg_item}) + + # finalize function-call items (in output-index order) + tc_items = {} + for idx, state in tc_state.items(): + yield self._event("response.function_call_arguments.done", { + "item_id": state["item_id"], "output_index": state["oi"], + "arguments": state["args"], + }) + fc_item = { + "id": state["item_id"], "type": "function_call", "status": "completed", + "call_id": state["call_id"], "name": state["name"], "arguments": state["args"], + } + tc_items[state["oi"]] = fc_item + yield self._event("response.output_item.done", + {"output_index": state["oi"], "item": fc_item}) + + # assemble final output items ordered by output index + ordered = [] + if msg_item_id is not None: + ordered.append((msg_oi, msg_item)) + ordered.extend(tc_items.items()) + self.output_items = [item for _, item in sorted(ordered, key=lambda kv: kv[0])] + + yield self._event("response.completed", + {"response": self._snapshot("completed", self.output_items)}) diff --git a/router.py b/router.py index 676e42b..a2f9dd8 100644 --- a/router.py +++ b/router.py @@ -290,6 +290,8 @@ from api.management import router as management_router app.include_router(management_router) from api.openai import router as openai_router app.include_router(openai_router) +from api.responses import router as responses_router +app.include_router(responses_router) from api.ollama import router as ollama_router app.include_router(ollama_router) @@ -322,6 +324,13 @@ async def startup_event() -> None: db = TokenDatabase(config.db_path) await db.init_db() + # Reconcile Responses-API background tasks lost across a restart: their + # in-memory asyncio task is gone but the DB row may still read queued / + # in_progress, so mark those failed to give polling clients a terminal state. + _orphaned = await db.fail_orphaned_responses() + if _orphaned: + print(f"[startup] Marked {_orphaned} orphaned background response(s) as failed.") + # Load existing token counts from database async for count_entry in db.load_token_counts(): endpoint = count_entry['endpoint'] diff --git a/test/test_responses.py b/test/test_responses.py new file mode 100644 index 0000000..73d3ff2 --- /dev/null +++ b/test/test_responses.py @@ -0,0 +1,460 @@ +"""Tests for the OpenAI Responses API support (api/responses.py + requests/responses.py). + +Covers the pure translation layer, the translated (Ollama-style) and native +(external-OpenAI) backend paths, conversation storage / chaining, background mode, +and the retrieve / delete / cancel routes. +""" +import asyncio +from contextlib import ExitStack, contextmanager +from types import SimpleNamespace as NS +from unittest.mock import AsyncMock, MagicMock, patch + +import orjson +import pytest + +import router +from api import responses as api_responses +from requests import responses as rt + + +# ────────────────────────────────────────────────────────────────────────────── +# Pure translation unit tests (no app / no I/O) +# ────────────────────────────────────────────────────────────────────────────── + +class TestTranslationInputToMessages: + def test_string_input(self): + msgs = rt.responses_input_to_messages("hello") + assert msgs == [{"role": "user", "content": "hello"}] + + def test_instructions_become_system(self): + msgs = rt.responses_input_to_messages("hi", instructions="be brief") + assert msgs[0] == {"role": "system", "content": "be brief"} + assert msgs[1] == {"role": "user", "content": "hi"} + + def test_item_list_text_and_image(self): + items = [{ + "type": "message", "role": "user", + "content": [ + {"type": "input_text", "text": "describe"}, + {"type": "input_image", "image_url": "http://x/y.png"}, + ], + }] + msgs = rt.responses_input_to_messages(items) + assert msgs[0]["role"] == "user" + assert msgs[0]["content"] == [ + {"type": "text", "text": "describe"}, + {"type": "image_url", "image_url": {"url": "http://x/y.png"}}, + ] + + def test_single_text_part_collapses_to_string(self): + items = [{"type": "message", "role": "user", + "content": [{"type": "input_text", "text": "yo"}]}] + assert rt.responses_input_to_messages(items)[0]["content"] == "yo" + + def test_function_call_roundtrip(self): + items = [ + {"type": "function_call", "call_id": "c1", "name": "get", "arguments": "{\"x\":1}"}, + {"type": "function_call_output", "call_id": "c1", "output": "42"}, + ] + msgs = rt.responses_input_to_messages(items) + assert msgs[0]["role"] == "assistant" + assert msgs[0]["tool_calls"][0]["id"] == "c1" + assert msgs[0]["tool_calls"][0]["function"]["name"] == "get" + assert msgs[1] == {"role": "tool", "tool_call_id": "c1", "content": "42"} + + +class TestTranslationResponseDirection: + def test_chat_message_to_output_items_text(self): + items = rt.chat_message_to_output_items({"role": "assistant", "content": "hi there"}) + assert len(items) == 1 + assert items[0]["type"] == "message" + assert items[0]["content"][0] == {"type": "output_text", "text": "hi there", "annotations": []} + + def test_chat_message_to_output_items_tool_call(self): + items = rt.chat_message_to_output_items({ + "role": "assistant", "content": None, + "tool_calls": [{"id": "c9", "function": {"name": "f", "arguments": "{}"}}], + }) + assert items[0]["type"] == "function_call" + assert items[0]["call_id"] == "c9" + assert items[0]["name"] == "f" + + def test_usage_mapping(self): + u = rt.usage_chat_to_responses({"prompt_tokens": 7, "completion_tokens": 3}) + assert u == {"input_tokens": 7, "output_tokens": 3, "total_tokens": 10} + + def test_build_response_object_output_text(self): + items = rt.chat_message_to_output_items({"role": "assistant", "content": "abc"}) + obj = rt.build_response_object(response_id="resp_1", model="m", output_items=items) + assert obj["object"] == "response" + assert obj["output_text"] == "abc" + assert obj["status"] == "completed" + + def test_tools_responses_to_chat(self): + tools = [{"type": "function", "name": "f", "description": "d", "parameters": {"type": "object"}}] + chat_tools = rt.tools_responses_to_chat(tools) + assert chat_tools == [{"type": "function", + "function": {"name": "f", "description": "d", + "parameters": {"type": "object"}}}] + + def test_messages_to_responses_input(self): + instr, items = rt.messages_to_responses_input([ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "yo"}, + ]) + assert instr == "sys" + assert items[0] == {"role": "user", "content": [{"type": "input_text", "text": "hi"}]} + assert items[1] == {"role": "assistant", "content": [{"type": "output_text", "text": "yo"}]} + + +# ────────────────────────────────────────────────────────────────────────────── +# Fakes for backend generators +# ────────────────────────────────────────────────────────────────────────────── + +def _fake_completion(content="hello world", usage=(3, 5)): + msg = MagicMock() + msg.model_dump.return_value = {"role": "assistant", "content": content} + usage_obj = MagicMock() + usage_obj.model_dump.return_value = { + "prompt_tokens": usage[0], "completion_tokens": usage[1], "total_tokens": sum(usage)} + return NS(choices=[NS(message=msg)], usage=usage_obj) + + +def _chunk(content=None, tool_calls=None): + return NS(choices=[NS(delta=NS(content=content, tool_calls=tool_calls), + finish_reason=None)], usage=None) + + +def _usage_chunk(p, c): + return NS(choices=[], usage=NS(prompt_tokens=p, completion_tokens=c)) + + +def _text_chunks(): + async def _gen(): + yield _chunk(content="Hel") + yield _chunk(content="lo") + yield _usage_chunk(3, 5) + return _gen() + + +def _toolcall_chunks(): + tc0 = NS(index=0, id="call_1", function=NS(name="lookup", arguments='{"q":')) + tc1 = NS(index=0, id=None, function=NS(name=None, arguments='"hi"}')) + + async def _gen(): + yield _chunk(tool_calls=[tc0]) + yield _chunk(tool_calls=[tc1]) + yield _usage_chunk(4, 2) + return _gen() + + +class _FakeEvent: + def __init__(self, data): + self._data = data + + def model_dump(self): + return self._data + + +def _native_event_stream(): + async def _gen(): + yield _FakeEvent({"type": "response.created", + "response": {"id": "resp_openai", "status": "in_progress", "output": []}}) + yield _FakeEvent({"type": "response.output_text.delta", + "item_id": "msg_1", "output_index": 0, "delta": "hi"}) + yield _FakeEvent({"type": "response.completed", "response": { + "id": "resp_openai", "status": "completed", + "output": [{"type": "message", "role": "assistant", + "content": [{"type": "output_text", "text": "hi"}]}], + "usage": {"input_tokens": 2, "output_tokens": 1, "total_tokens": 3}}}) + return _gen() + + +def _sse_events(text): + """Split an SSE body into a list of (event_type, data_dict).""" + out = [] + for frame in text.strip().split("\n\n"): + if not frame.strip(): + continue + etype = data = None + for line in frame.splitlines(): + if line.startswith("event: "): + etype = line[len("event: "):] + elif line.startswith("data: "): + data = orjson.loads(line[len("data: "):]) + out.append((etype, data)) + return out + + +@contextmanager +def _enter(*cms): + """Enter a variable number of context managers (works with *unpacked tuples).""" + with ExitStack() as stack: + for cm in cms: + stack.enter_context(cm) + yield + + +def _patch_backend(native=False, endpoint="http://ollama:11434"): + """Context managers patching endpoint selection + client construction.""" + return ( + patch.object(api_responses, "choose_endpoint", + AsyncMock(return_value=(endpoint, "test-model:latest"))), + patch.object(api_responses, "decrement_usage", AsyncMock()), + patch.object(api_responses, "is_ext_openai_endpoint", return_value=native), + patch.object(api_responses, "_make_openai_client", return_value=MagicMock()), + patch.object(api_responses, "get_llm_cache", return_value=None), + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# Translated path (Ollama-style backend) +# ────────────────────────────────────────────────────────────────────────────── + +class TestTranslatedPath: + async def test_nonstream(self, client): + with _enter(*_patch_backend(native=False), + patch.object(api_responses, "create_chat_with_retries", + AsyncMock(return_value=_fake_completion("hello world")))): + resp = await client.post("/v1/responses", + json={"model": "test-model", "input": "hi", "store": False}) + assert resp.status_code == 200 + body = resp.json() + assert body["object"] == "response" + assert body["output_text"] == "hello world" + assert body["usage"] == {"input_tokens": 3, "output_tokens": 5, "total_tokens": 8} + assert body["id"].startswith("resp_") + + async def test_stream_event_sequence(self, client): + with _enter(*_patch_backend(native=False), + patch.object(api_responses, "create_chat_with_retries", + AsyncMock(return_value=_text_chunks()))): + resp = await client.post("/v1/responses", + json={"model": "test-model", "input": "hi", + "stream": True, "store": False}) + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("text/event-stream") + events = _sse_events(resp.content.decode()) + types = [e[0] for e in events] + assert types[0] == "response.created" + assert "response.output_text.delta" in types + assert types[-1] == "response.completed" + # concatenated deltas reconstruct the content + deltas = "".join(d["delta"] for t, d in events if t == "response.output_text.delta") + assert deltas == "Hello" + # completed event carries usage + completed = [d for t, d in events if t == "response.completed"][0] + assert completed["response"]["usage"]["input_tokens"] == 3 + + async def test_stream_tool_calls(self, client): + with _enter(*_patch_backend(native=False), + patch.object(api_responses, "create_chat_with_retries", + AsyncMock(return_value=_toolcall_chunks()))): + resp = await client.post("/v1/responses", + json={"model": "test-model", "input": "lookup hi", + "stream": True, "store": False}) + events = _sse_events(resp.content.decode()) + types = [e[0] for e in events] + assert "response.function_call_arguments.delta" in types + assert "response.function_call_arguments.done" in types + args = "".join(d["delta"] for t, d in events + if t == "response.function_call_arguments.delta") + assert args == '{"q":"hi"}' + completed = [d for t, d in events if t == "response.completed"][0] + fc = [i for i in completed["response"]["output"] if i["type"] == "function_call"][0] + assert fc["name"] == "lookup" + assert fc["arguments"] == '{"q":"hi"}' + + +# ────────────────────────────────────────────────────────────────────────────── +# Native path (external OpenAI backend) +# ────────────────────────────────────────────────────────────────────────────── + +class TestNativePath: + async def test_nonstream_passthrough_rewrites_id(self, client): + oclient = MagicMock() + resp_obj = MagicMock() + resp_obj.model_dump.return_value = { + "id": "resp_openai", "status": "completed", + "output": [{"type": "message", "role": "assistant", + "content": [{"type": "output_text", "text": "native hi"}]}], + "usage": {"input_tokens": 2, "output_tokens": 3, "total_tokens": 5}} + oclient.responses.create = AsyncMock(return_value=resp_obj) + with (patch.object(api_responses, "choose_endpoint", + AsyncMock(return_value=("https://api.openai.com/v1", "gpt"))), + patch.object(api_responses, "decrement_usage", AsyncMock()), + patch.object(api_responses, "is_ext_openai_endpoint", return_value=True), + patch.object(api_responses, "_make_openai_client", return_value=oclient), + patch.object(api_responses, "get_llm_cache", return_value=None)): + resp = await client.post("/v1/responses", + json={"model": "gpt", "input": "hi", "store": False}) + body = resp.json() + assert body["output_text"] == "native hi" + assert body["id"].startswith("resp_") and body["id"] != "resp_openai" + # native call must not delegate state upstream + assert oclient.responses.create.call_args.kwargs["store"] is False + + async def test_stream_passthrough(self, client): + oclient = MagicMock() + oclient.responses.create = AsyncMock(return_value=_native_event_stream()) + with (patch.object(api_responses, "choose_endpoint", + AsyncMock(return_value=("https://api.openai.com/v1", "gpt"))), + patch.object(api_responses, "decrement_usage", AsyncMock()), + patch.object(api_responses, "is_ext_openai_endpoint", return_value=True), + patch.object(api_responses, "_make_openai_client", return_value=oclient), + patch.object(api_responses, "get_llm_cache", return_value=None)): + resp = await client.post("/v1/responses", + json={"model": "gpt", "input": "hi", + "stream": True, "store": False}) + events = _sse_events(resp.content.decode()) + # the completed event's response id is rewritten to the router id + completed = [d for t, d in events if t == "response.completed"][0] + assert completed["response"]["id"].startswith("resp_") + assert completed["response"]["id"] != "resp_openai" + + +# ────────────────────────────────────────────────────────────────────────────── +# Storage + chaining + retrieve/delete +# ────────────────────────────────────────────────────────────────────────────── + +class TestStorageAndChaining: + async def test_store_and_retrieve(self, client): + with _enter(*_patch_backend(native=False), + patch.object(api_responses, "create_chat_with_retries", + AsyncMock(return_value=_fake_completion("remembered")))): + created = await client.post("/v1/responses", + json={"model": "test-model", "input": "hi", "store": True}) + rid = created.json()["id"] + got = await client.get(f"/v1/responses/{rid}") + assert got.status_code == 200 + assert got.json()["output_text"] == "remembered" + + async def test_previous_response_id_rehydrates_history(self, client): + # First turn + with _enter(*_patch_backend(native=False), + patch.object(api_responses, "create_chat_with_retries", + AsyncMock(return_value=_fake_completion("turn-one")))): + first = await client.post("/v1/responses", + json={"model": "test-model", "input": "first?", "store": True}) + rid = first.json()["id"] + + # Second turn references the first — capture the messages sent to the backend + capture = AsyncMock(return_value=_fake_completion("turn-two")) + with _enter(*_patch_backend(native=False), + patch.object(api_responses, "create_chat_with_retries", capture)): + await client.post("/v1/responses", + json={"model": "test-model", "input": "second?", + "previous_response_id": rid, "store": True}) + sent_messages = capture.call_args.args[1]["messages"] + contents = [m.get("content") for m in sent_messages] + assert "first?" in contents # prior user turn replayed + assert "turn-one" in contents # prior assistant turn replayed + assert "second?" in contents # current turn appended + + async def test_delete(self, client): + with _enter(*_patch_backend(native=False), + patch.object(api_responses, "create_chat_with_retries", + AsyncMock(return_value=_fake_completion("bye")))): + created = await client.post("/v1/responses", + json={"model": "test-model", "input": "hi", "store": True}) + rid = created.json()["id"] + deleted = await client.delete(f"/v1/responses/{rid}") + assert deleted.status_code == 200 + assert deleted.json()["deleted"] is True + assert (await client.get(f"/v1/responses/{rid}")).status_code == 404 + + async def test_retrieve_missing_404(self, client): + assert (await client.get("/v1/responses/resp_missing")).status_code == 404 + + +# ────────────────────────────────────────────────────────────────────────────── +# Background mode +# ────────────────────────────────────────────────────────────────────────────── + +class TestBackgroundMode: + async def test_background_requires_store(self, client): + resp = await client.post("/v1/responses", + json={"model": "test-model", "input": "hi", + "background": True, "store": False}) + assert resp.status_code == 400 + + async def test_background_lifecycle(self, client): + with _enter(*_patch_backend(native=False), + patch.object(api_responses, "create_chat_with_retries", + AsyncMock(return_value=_fake_completion("bg-done")))): + created = await client.post("/v1/responses", + json={"model": "test-model", "input": "hi", + "background": True, "store": True}) + assert created.status_code == 200 + assert created.json()["status"] == "queued" + rid = created.json()["id"] + # poll until terminal + status = None + for _ in range(100): + await asyncio.sleep(0.01) + got = await client.get(f"/v1/responses/{rid}") + status = got.json()["status"] + if status in ("completed", "failed", "cancelled"): + break + assert status == "completed" + assert got.json()["output_text"] == "bg-done" + + async def test_fail_orphaned_responses(self, client): + db = router.db + await db.store_response("resp_orphan", previous_response_id=None, model="m", + status="in_progress", created_at=0, input_messages=[]) + n = await db.fail_orphaned_responses() + assert n >= 1 + row = await db.get_response("resp_orphan") + assert row["status"] == "failed" + + +# ────────────────────────────────────────────────────────────────────────────── +# Cache parity +# ────────────────────────────────────────────────────────────────────────────── + +class _FakeCache: + def __init__(self, response_bytes): + self._resp = response_bytes + self.calls = [] + + async def get_chat(self, route, model, messages): + self.calls.append((route, model, messages)) + return self._resp + + +class TestCacheParity: + async def test_cache_hit_served_as_response(self, client): + cached = orjson.dumps(rt.build_response_object( + response_id="resp_cached", model="test-model", + output_items=rt.chat_message_to_output_items( + {"role": "assistant", "content": "from-cache"}))) + fake = _FakeCache(cached) + with (patch.object(api_responses, "get_llm_cache", return_value=fake), + patch.object(api_responses, "choose_endpoint", + AsyncMock(side_effect=AssertionError("backend must not be reached")))): + resp = await client.post("/v1/responses", + json={"model": "test-model", "input": "ping", + "store": False, "nomyo": {"cache": True}}) + assert resp.status_code == 200 + assert resp.json()["output_text"] == "from-cache" + assert fake.calls and fake.calls[0][0] == "openai_responses" + + async def test_cache_hit_served_as_sse(self, client): + cached = orjson.dumps(rt.build_response_object( + response_id="resp_cached", model="test-model", + output_items=rt.chat_message_to_output_items( + {"role": "assistant", "content": "from-cache"}))) + fake = _FakeCache(cached) + with (patch.object(api_responses, "get_llm_cache", return_value=fake), + patch.object(api_responses, "choose_endpoint", + AsyncMock(side_effect=AssertionError("backend must not be reached")))): + resp = await client.post("/v1/responses", + json={"model": "test-model", "input": "ping", + "stream": True, "store": False, + "nomyo": {"cache": True}}) + assert resp.headers["content-type"].startswith("text/event-stream") + events = _sse_events(resp.content.decode()) + deltas = "".join(d["delta"] for t, d in events if t == "response.output_text.delta") + assert deltas == "from-cache" From 5184123fd20d21c9f3cfc69a4a7818f9ff7ccdb0 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Sat, 13 Jun 2026 10:22:20 +0200 Subject: [PATCH 2/5] fix: improve routing logic to favour unloaded backends instead of looking at per-model load now looking at backend total load --- routing.py | 42 +++++++++++++++++++++--------------- test/test_choose_endpoint.py | 27 +++++++++++++++++++++++ 2 files changed, 52 insertions(+), 17 deletions(-) diff --git a/routing.py b/routing.py index 059cf9a..6a0e205 100644 --- a/routing.py +++ b/routing.py @@ -202,6 +202,22 @@ async def choose_endpoint(model: str, reserve: bool = True, def utilization_ratio(ep: str) -> float: return tracking_usage(ep) / get_max_connections(ep) + def total_load(ep: str) -> int: + """Sum of in-flight requests across *all* models on the endpoint.""" + return sum(usage_counts.get(ep, {}).values()) + + def pick_least_loaded(eps: list[str]) -> str: + """Pick the endpoint with the lowest total load, breaking ties at + random. Using total load (not per-model usage) for both the ranking + *and* the tie-break is what keeps a request off a backend already + busy with a *different* model — otherwise the per-model count reads + zero everywhere and the ranking gets discarded. See issue: a cold + model B would land on the backend already serving model A while + other backends sat idle.""" + min_load = min(total_load(ep) for ep in eps) + tied = [ep for ep in eps if total_load(ep) == min_load] + return random.choice(tied) + # Priority map: position in all_endpoints list (lower = higher priority) ep_priority = {ep: i for i, ep in enumerate(all_endpoints)} @@ -235,15 +251,11 @@ async def choose_endpoint(model: str, reserve: bool = True, loaded_and_free.sort(key=utilization_ratio) selected = loaded_and_free[0] else: - # Sort ascending for load balancing — all endpoints here already have the - # model loaded, so there is no model-switching cost to optimise for. - loaded_and_free.sort(key=tracking_usage) - # When all candidates are equally idle, randomise to avoid always picking - # the first entry in a stable sort. - if all(tracking_usage(ep) == 0 for ep in loaded_and_free): - selected = random.choice(loaded_and_free) - else: - selected = loaded_and_free[0] + # All endpoints here already have the model loaded, so there + # is no model-switching cost to optimise for. Pick the least + # *total*-loaded one (tie broken at random) so we steer away + # from a backend busy serving other models. + selected = pick_least_loaded(loaded_and_free) else: # 4️⃣ Endpoints among the candidates that simply have a free slot endpoints_with_free_slot = [ @@ -257,14 +269,10 @@ async def choose_endpoint(model: str, reserve: bool = True, endpoints_with_free_slot.sort(key=utilization_ratio) selected = endpoints_with_free_slot[0] else: - # Sort by total endpoint load (ascending) to prefer idle endpoints. - endpoints_with_free_slot.sort( - key=lambda ep: sum(usage_counts.get(ep, {}).values()) - ) - if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot): - selected = random.choice(endpoints_with_free_slot) - else: - selected = endpoints_with_free_slot[0] + # Prefer the endpoint with the lowest *total* load so the + # cold-start cost lands on genuinely idle hardware rather + # than a backend already busy with a different model. + selected = pick_least_loaded(endpoints_with_free_slot) else: # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue) if config.priority_routing: diff --git a/test/test_choose_endpoint.py b/test/test_choose_endpoint.py index ece609a..a6a7905 100644 --- a/test/test_choose_endpoint.py +++ b/test/test_choose_endpoint.py @@ -85,6 +85,33 @@ class TestChooseEndpointBasic: ep, _ = await router.choose_endpoint("llama3.2:latest") assert ep in (EP1, EP2) + async def test_cold_model_avoids_backend_busy_with_other_model(self): + # Regression: heterogeneous cluster. A cold model B (loaded nowhere) + # must not be routed to a backend already serving a *different* model + # while other backends sit idle. The step-4 idle check used to look at + # per-model usage (zero everywhere for B) and discard the total-load + # ranking, so B could land on the busy backend at random. + cfg = _make_cfg([EP1, EP2, EP3], max_conn=4) + + async def available(ep, *_): + return {"model-a:latest", "model-b:latest"} + + # EP3 is busy with model A; EP1 and EP2 are completely idle. Model B + # is loaded nowhere. + router.usage_counts[EP3]["model-a:latest"] = 1 + + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", side_effect=available), + patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())), + ): + # Run repeatedly: the busy backend must be excluded every time, + # the idle two share the load at random. + for _ in range(50): + ep, _ = await router.choose_endpoint("model-b:latest", reserve=False) + assert ep in (EP1, EP2) + assert ep != EP3 + async def test_saturated_picks_least_busy(self): cfg = _make_cfg([EP1, EP2]) cfg.max_concurrent_connections = 1 From c8da58430a912bb21d16123efc9448aa53d49776 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Sat, 13 Jun 2026 15:54:46 +0200 Subject: [PATCH 3/5] fix: logic extend on total_load AND loaded_count --- routing.py | 39 +++++++++++++++++++++++++++--------- test/test_choose_endpoint.py | 31 ++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 9 deletions(-) diff --git a/routing.py b/routing.py index 6a0e205..ecf6803 100644 --- a/routing.py +++ b/routing.py @@ -206,16 +206,37 @@ async def choose_endpoint(model: str, reserve: bool = True, """Sum of in-flight requests across *all* models on the endpoint.""" return sum(usage_counts.get(ep, {}).values()) + # How many models each candidate currently has *resident* (from the + # /api/ps probe). With infinite keep-alive a model stays loaded long + # after its in-flight count drops to zero, so this is the signal that + # spreads *distinct* models across backends. + ep_loaded_counts = { + ep: len(models) for ep, models in zip(candidate_endpoints, loaded_sets) + } + + def loaded_count(ep: str) -> int: + return ep_loaded_counts.get(ep, 0) + def pick_least_loaded(eps: list[str]) -> str: - """Pick the endpoint with the lowest total load, breaking ties at - random. Using total load (not per-model usage) for both the ranking - *and* the tie-break is what keeps a request off a backend already - busy with a *different* model — otherwise the per-model count reads - zero everywhere and the ranking gets discarded. See issue: a cold - model B would land on the backend already serving model A while - other backends sat idle.""" - min_load = min(total_load(ep) for ep in eps) - tied = [ep for ep in eps if total_load(ep) == min_load] + """Pick the least-committed endpoint, breaking ties at random. + + Ordering key is ``(total_load, loaded_count)``: + + * ``total_load`` (in-flight requests across *all* models) keeps a + request off a backend already busy with a *different* model — + otherwise the per-model count reads zero everywhere and the + ranking is discarded (cold model B landing on the box serving A). + * ``loaded_count`` (number of *resident* models) then spreads + distinct models across backends. Two different cold models (27b, + 35b) requested back-to-back must not pile onto the same box: once + 27b is resident there, that box has loaded_count 1 while the idle + backends have 0, so the next cold model prefers an empty backend + even though every backend reports zero in-flight load. + + ``random.choice`` only breaks genuine ties on both keys, so a single + idle cluster still distributes the very first cold model evenly.""" + best = min((total_load(ep), loaded_count(ep)) for ep in eps) + tied = [ep for ep in eps if (total_load(ep), loaded_count(ep)) == best] return random.choice(tied) # Priority map: position in all_endpoints list (lower = higher priority) diff --git a/test/test_choose_endpoint.py b/test/test_choose_endpoint.py index a6a7905..be75f82 100644 --- a/test/test_choose_endpoint.py +++ b/test/test_choose_endpoint.py @@ -112,6 +112,37 @@ class TestChooseEndpointBasic: assert ep in (EP1, EP2) assert ep != EP3 + async def test_two_cold_models_spread_across_backends(self): + # Regression: 3 backends all advertise all models. Two *different* + # cold models requested back-to-back must land on *different* + # backends. Once model-a is resident on the chosen backend (infinite + # keep-alive), its in-flight count drops back to 0 — so only the + # resident-model count distinguishes the backends. Without it, the + # second cold model would randomly re-collide on the busy backend. + cfg = _make_cfg([EP1, EP2, EP3], max_conn=4) + + async def available(ep, *_): + return {"model-a:latest", "model-b:latest"} + + # model-a finished loading on EP1 and stays resident; its request has + # completed so EP1 has zero in-flight load, same as EP2/EP3. + loaded = {EP1: {"model-a:latest"}, EP2: set(), EP3: set()} + + async def loaded_models(ep): + return loaded[ep] + + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", side_effect=available), + patch.object(router.fetch, "loaded_models", side_effect=loaded_models), + ): + # A cold model-b must avoid EP1 (which already holds model-a) and + # go to one of the empty backends, every time. + for _ in range(50): + ep, _ = await router.choose_endpoint("model-b:latest", reserve=False) + assert ep in (EP2, EP3) + assert ep != EP1 + async def test_saturated_picks_least_busy(self): cfg = _make_cfg([EP1, EP2]) cfg.max_concurrent_connections = 1 From aa8baebac5184b8c0abaffea676802007d96dc2f Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Sun, 14 Jun 2026 16:34:31 +0200 Subject: [PATCH 4/5] feat: add llama-swap as a backend --- api/management.py | 14 +++-- api/ollama.py | 85 +++++++++++++++++++++------ api/openai.py | 95 +++++++++++++++++++++++++++--- backends/control.py | 50 ++++++++++++++++ backends/normalize.py | 39 +++++++++---- backends/probe.py | 38 +++++++++++- config.py | 4 ++ config.yaml | 13 ++++- doc/configuration.md | 31 ++++++++++ router.py | 6 +- routing.py | 8 ++- test/config_test.yaml | 4 ++ test/conftest.py | 3 + test/test_choose_endpoint.py | 24 +++++++- test/test_fetch.py | 27 ++++++++- test/test_llama_swap.py | 109 +++++++++++++++++++++++++++++++++++ test/test_unit_helpers.py | 46 +++++++++++++++ 17 files changed, 544 insertions(+), 52 deletions(-) create mode 100644 backends/control.py create mode 100644 test/test_llama_swap.py diff --git a/api/management.py b/api/management.py index ac1f356..0e9ecc2 100644 --- a/api/management.py +++ b/api/management.py @@ -27,7 +27,7 @@ from state import ( _affinity_lock, ) from sse import subscribe, unsubscribe -from backends.normalize import _normalize_llama_model_name +from backends.normalize import _normalize_llama_model_name, is_llama_server, llama_endpoints from backends.probe import _endpoint_health @@ -127,7 +127,6 @@ async def affinity_stats(request: Request): now = time.monotonic() entries: list[dict] = [] - llama_eps = set(config.llama_server_endpoints) async with _affinity_lock: for fp, (ep, mdl, expires_at) in list(_affinity_map.items()): remaining = expires_at - now @@ -136,7 +135,7 @@ async def affinity_stats(request: Request): continue # Mirror the normalisation used by /api/ps_details so the dashboard # can join affinity entries to PS rows by (endpoint, model). - display_model = _normalize_llama_model_name(mdl) if ep in llama_eps else mdl + display_model = _normalize_llama_model_name(mdl) if is_llama_server(ep) else mdl entries.append({ "endpoint": ep, "model": display_model, @@ -175,9 +174,12 @@ async def config_proxy(request: Request): ollama_results = await asyncio.gather(*[check(ep) for ep in config.endpoints]) llama_results = [] - if config.llama_server_endpoints: + # llama-server and llama-swap render identically in the dashboard ("llama" rows), + # so health-check both and merge them into one list. + llama_eps = llama_endpoints(config) + if llama_eps: llama_results = await asyncio.gather( - *[check(ep) for ep in config.llama_server_endpoints] + *[check(ep) for ep in llama_eps] ) return { @@ -227,7 +229,7 @@ async def health_proxy(request: Request): # purposes. Probing /api/version alone would miss the case where the # Ollama process is up but /api/ps is failing — see issue #83. all_endpoints = list(config.endpoints) - llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints] + llama_eps_extra = [ep for ep in llama_endpoints(config) if ep not in config.endpoints] all_endpoints += llama_eps_extra probe_results = await asyncio.gather( diff --git a/api/ollama.py b/api/ollama.py index afba243..0fc98aa 100644 --- a/api/ollama.py +++ b/api/ollama.py @@ -40,9 +40,12 @@ from backends.health import ( from backends.normalize import ( dedupe_on_keys, is_openai_compatible, + is_llama_server, + llama_endpoints, _normalize_llama_model_name, _extract_llama_quant, ) +from backends.control import unload_model from backends.probe import fetch from backends.sessions import _make_openai_client, get_ollama_client, get_probe_session from requests.chat import _make_moe_requests @@ -372,7 +375,7 @@ async def chat_proxy(request: Request): if use_openai: start_ts = time.perf_counter() # Proactive trim: only for small-ctx models we've already seen run out of space - _lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model + _lookup_model = _normalize_llama_model_name(model) if is_llama_server(endpoint) else model _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model)) if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT: _pre_target = int((_known_nctx - _known_nctx // 4) / 1.2) @@ -935,8 +938,8 @@ async def tags_proxy(request: Request): # 1. Query all endpoints for models tasks = [fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep] tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys[ep], skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" in ep] - # Also query llama-server endpoints not already covered by config.endpoints - llama_eps_for_tags = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints] + # Also query llama-server / llama-swap endpoints not already covered by config.endpoints + llama_eps_for_tags = [ep for ep in llama_endpoints(config) if ep not in config.endpoints] tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in llama_eps_for_tags] all_models = await asyncio.gather(*tasks) @@ -960,27 +963,42 @@ async def tags_proxy(request: Request): ) +async def _fetch_llama_swap_running(endpoint: str) -> list[dict]: + """Return the list of ready (`state == "ready"`) workers from a llama-swap + endpoint's `/running` route. llama-swap omits the per-model `status` field on + `/v1/models`, so running workers must be read here instead. + """ + config = get_config() + base_url = endpoint.rstrip("/").removesuffix("/v1") + return await fetch.endpoint_details( + base_url, "/running", "running", config.api_keys.get(endpoint), + skip_error_cache=True, timeout=8, + ) + + @router.get("/api/ps") async def ps_proxy(request: Request): """ - Proxy a ps request to all Ollama and llama-server endpoints and reply a unique list of all running models. + Proxy a ps request to all Ollama, llama-server and llama-swap endpoints and reply a unique list of all running models. For Ollama endpoints: queries /api/ps For llama-server endpoints: queries /v1/models with status.value == "loaded" + For llama-swap endpoints: queries /running (state == "ready") """ config = get_config() # 1. Query Ollama endpoints for running models via /api/ps ollama_tasks = [fetch.endpoint_details(ep, "/api/ps", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep] # 2. Query llama-server endpoints for loaded models via /v1/models - # Also query endpoints from llama_server_endpoints that may not be in config.endpoints - all_llama_endpoints = set(config.llama_server_endpoints) | set(ep for ep in config.endpoints if ep in config.llama_server_endpoints) llama_tasks = [ fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) - for ep in all_llama_endpoints + for ep in config.llama_server_endpoints ] + # 3. Query llama-swap endpoints for running workers via /running + swap_tasks = [_fetch_llama_swap_running(ep) for ep in config.llama_swap_endpoints] ollama_loaded = await asyncio.gather(*ollama_tasks) if ollama_tasks else [] llama_loaded = await asyncio.gather(*llama_tasks) if llama_tasks else [] + swap_running = await asyncio.gather(*swap_tasks) if swap_tasks else [] models = {'models': []} # Add Ollama models (if any) @@ -1003,6 +1021,21 @@ async def ps_proxy(request: Request): "status": item.get("status"), "details": {"quantization_level": quant} if quant else {} }) + # Add llama-swap running workers (already filtered on state == "ready") + if swap_running: + for runlist in swap_running: + for item in runlist: + if item.get("state") != "ready": + continue + raw_id = item.get("model", "") + normalized = _normalize_llama_model_name(raw_id) + quant = _extract_llama_quant(raw_id) + models['models'].append({ + "name": normalized, + "id": normalized, + "digest": "", + "details": {"quantization_level": quant} if quant else {} + }) # 3. Return a JSONResponse with deduplicated currently deployed models # Deduplicate on 'name' rather than 'digest': llama-server models always @@ -1101,16 +1134,7 @@ async def ps_details_proxy(request: Request): is_generation = "temperature" in dgs if is_sleeping: - unload_url = f"{base_url}/models/unload" - try: - async with client.post( - unload_url, - json={"model": model_id}, - headers=headers, - ) as unload_resp: - print(f"[ps_details] Unloaded sleeping model {model_id} from {endpoint}: {unload_resp.status}") - except Exception as ue: - print(f"[ps_details] Failed to unload sleeping model {model_id} from {endpoint}: {ue}") + await unload_model(endpoint, model_id) return n_ctx, is_sleeping, is_generation except Exception as e: @@ -1131,4 +1155,31 @@ async def ps_details_proxy(request: Request): if not is_sleeping: models.append(model_dict) + # Add llama-swap running workers (read from /running; no status/props/auto-unload — + # llama-swap omits the status field on /v1/models and manages its own TTL eviction). + if config.llama_swap_endpoints: + swap_running = await asyncio.gather( + *[_fetch_llama_swap_running(ep) for ep in config.llama_swap_endpoints] + ) + for endpoint, runlist in zip(config.llama_swap_endpoints, swap_running): + for item in runlist: + if not isinstance(item, dict) or item.get("state") != "ready": + continue + raw_id = item.get("model", "") + if not raw_id: + continue + normalized = _normalize_llama_model_name(raw_id) + quant = _extract_llama_quant(raw_id) + models.append({ + "name": normalized, + "id": normalized, + "original_name": raw_id, + "digest": "", + "details": {"quantization_level": quant} if quant else {}, + "endpoint": endpoint, + "state": item.get("state"), + "ttl": item.get("ttl"), + "proxy": item.get("proxy"), + }) + return JSONResponse(content={"models": models}, status_code=200) diff --git a/api/openai.py b/api/openai.py index ab24f54..1f0d22d 100644 --- a/api/openai.py +++ b/api/openai.py @@ -34,6 +34,8 @@ from backends.normalize import ( ep2base, is_ext_openai_endpoint, is_openai_compatible, + is_llama_server, + llama_endpoints, _normalize_llama_model_name, ) from backends.probe import fetch @@ -353,7 +355,7 @@ async def openai_chat_completions_proxy(request: Request): resolved_msgs = await _normalize_images_in_messages(params.get("messages", [])) send_params = {**params, "messages": resolved_msgs} # Proactive trim: only for small-ctx models we've already seen run out of space - _lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model + _lookup_model = _normalize_llama_model_name(model) if is_llama_server(endpoint) else model _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model)) if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT: _pre_target = int(((_known_nctx - _known_nctx // 4)) / 1.2) @@ -658,9 +660,9 @@ async def openai_models_proxy(request: Request): ollama_tasks = [fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep] # 2. Query external OpenAI endpoints (Groq, OpenAI, etc.) via /models ext_openai_tasks = [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in config.endpoints if is_ext_openai_endpoint(ep)] - # 3. Query llama-server endpoints for loaded models via /v1/models - # Also query endpoints from llama_server_endpoints that may not be in config.endpoints - all_llama_endpoints = set(config.llama_server_endpoints) | set(ep for ep in config.endpoints if ep in config.llama_server_endpoints) + # 3. Query llama-server / llama-swap endpoints for advertised models via /v1/models + # Also query endpoints that may not be in config.endpoints + all_llama_endpoints = llama_endpoints(config) llama_tasks = [ fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in all_llama_endpoints @@ -783,10 +785,10 @@ async def rerank_proxy(request: Request): upstream_payload[optional_key] = payload[optional_key] # Determine upstream URL: - # llama-server exposes /v1/rerank (base already contains /v1 for llama_server_endpoints) + # llama-server / llama-swap expose /v1/rerank (base already contains /v1) # External OpenAI endpoints expose /rerank under their /v1 base - if endpoint in config.llama_server_endpoints: - # llama-server: endpoint may or may not already contain /v1 + if is_llama_server(endpoint): + # llama-server / llama-swap: endpoint may or may not already contain /v1 if "/v1" in endpoint: rerank_url = f"{endpoint}/rerank" else: @@ -823,3 +825,82 @@ async def rerank_proxy(request: Request): return JSONResponse(content=data) finally: await decrement_usage(endpoint, tracking_model) + + +async def _resolve_llama_swap_endpoint(model_id: str) -> str | None: + """Pick the llama-swap endpoint that serves ``model_id``. + + Prefers an endpoint that already has the worker running; falls back to any + that advertises the model. Returns None if none do. + """ + config = get_config() + swap_eps = config.llama_swap_endpoints + if not swap_eps: + return None + + advertised = await asyncio.gather( + *[fetch.available_models(ep, config.api_keys.get(ep)) for ep in swap_eps] + ) + candidates = [ep for ep, models in zip(swap_eps, advertised) if model_id in models] + if not candidates: + return None + if len(candidates) == 1: + return candidates[0] + + loaded = await asyncio.gather(*[fetch.loaded_models(ep) for ep in candidates]) + for ep, lm in zip(candidates, loaded): + if model_id in lm: + return ep + return candidates[0] + + +@router.api_route("/upstream/{model_id}/{path:path}", methods=["GET", "POST"]) +async def llama_swap_upstream(model_id: str, path: str, request: Request): + """Bypass llama-swap and reach a model's underlying llama-server worker directly + via llama-swap's ``/upstream/:model_id`` route. + + Lets clients use llama-server features that llama-swap itself does not forward + (e.g. token-array prompts), while still letting the router pick the backend that + actually hosts the model. ``/upstream`` is a root route, so the ``/v1`` suffix is + stripped from the configured endpoint. + """ + config = get_config() + endpoint = await _resolve_llama_swap_endpoint(model_id) + if endpoint is None: + raise HTTPException( + status_code=404, + detail=f"No configured llama-swap endpoint serves model '{model_id}'.", + ) + + base_url = endpoint.rstrip("/").removesuffix("/v1") + url = f"{base_url}/upstream/{model_id}/{path}" + if request.url.query: + url = f"{url}?{request.url.query}" + + headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")} + content_type = request.headers.get("content-type") + if content_type: + headers["Content-Type"] = content_type + api_key = config.api_keys.get(endpoint) + if api_key is not None: + headers["Authorization"] = "Bearer " + api_key + + body = await request.body() + client: aiohttp.ClientSession = get_session(endpoint) + try: + resp = await client.request(request.method, url, data=body or None, headers=headers) + except Exception as e: + raise HTTPException(status_code=502, detail=f"Upstream request to {url} failed: {e}") + + async def _iter(): + try: + async for chunk in resp.content.iter_any(): + yield chunk + finally: + resp.release() + + return StreamingResponse( + _iter(), + status_code=resp.status, + media_type=resp.headers.get("Content-Type"), + ) diff --git a/backends/control.py b/backends/control.py new file mode 100644 index 0000000..fda5fe3 --- /dev/null +++ b/backends/control.py @@ -0,0 +1,50 @@ +"""Backend control operations (model unload). + +llama-server and llama-swap evict a resident model through different routes: + * llama-server → ``POST {base}/models/unload`` with body ``{"model": id}`` + * llama-swap → ``POST {base}/api/models/unload/{id}`` (path parameter) + +``unload_model`` dispatches on the configured backend type so callers don't +have to know which one they are talking to. Both routes live at the endpoint +root, so any ``/v1`` suffix is stripped first. +""" +from typing import Optional + +import aiohttp + +from config import get_config +from state import default_headers +from backends.sessions import get_probe_session +from backends.normalize import is_llama_swap +from backends.health import _format_connection_issue + + +async def unload_model(endpoint: str, model_id: str) -> bool: + """Ask ``endpoint`` to unload ``model_id``. Returns True on a 2xx response. + + ``model_id`` must be the backend's native model identifier (the raw HF id + for llama-server / llama-swap), not the router-normalized display name. + """ + cfg = get_config() + base_url = endpoint.rstrip("/").removesuffix("/v1") + headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")} + api_key: Optional[str] = cfg.api_keys.get(endpoint) + if api_key is not None: + headers["Authorization"] = "Bearer " + api_key + + if is_llama_swap(endpoint): + url = f"{base_url}/api/models/unload/{model_id}" + json_body = None + else: + url = f"{base_url}/models/unload" + json_body = {"model": model_id} + + client: aiohttp.ClientSession = get_probe_session(endpoint) + try: + async with client.post(url, json=json_body, headers=headers) as resp: + ok = resp.status < 400 + print(f"[unload_model] {model_id} on {endpoint}: {resp.status}") + return ok + except Exception as e: + print(f"[unload_model] {_format_connection_issue(url, e)}") + return False diff --git a/backends/normalize.py b/backends/normalize.py index 6603f9d..41fc199 100644 --- a/backends/normalize.py +++ b/backends/normalize.py @@ -50,27 +50,46 @@ def dedupe_on_keys(dicts, key_fields): return out +def is_llama_swap(endpoint: str) -> bool: + """True if the endpoint is a configured llama-swap front.""" + return endpoint in get_config().llama_swap_endpoints + + +def is_llama_server(endpoint: str) -> bool: + """True for a llama.cpp llama-server OR a llama-swap front. + + Both speak the same OpenAI-compatible surface, so the router treats them + identically everywhere except loaded-model detection and model unload. + """ + cfg = get_config() + return endpoint in cfg.llama_server_endpoints or endpoint in cfg.llama_swap_endpoints + + +def llama_endpoints(cfg) -> list: + """Combined, de-duplicated llama-server + llama-swap endpoints (order preserved).""" + return list(dict.fromkeys([*cfg.llama_server_endpoints, *cfg.llama_swap_endpoints])) + + def is_ext_openai_endpoint(endpoint: str) -> bool: """ - Determine if an endpoint is an external OpenAI-compatible endpoint (not Ollama or llama-server). + Determine if an endpoint is an external OpenAI-compatible endpoint (not Ollama, llama-server or llama-swap). Returns True for: - External services like OpenAI.com, Groq, etc. Returns False for: - Ollama endpoints (without /v1, or with /v1 but default port 11434) - - llama-server endpoints (explicitly configured in llama_server_endpoints) + - llama-server / llama-swap endpoints (explicitly configured) """ - cfg = get_config() - # Check if it's a llama-server endpoint (has /v1 and is in the configured list) - if endpoint in cfg.llama_server_endpoints: + # Check if it's a llama-server / llama-swap endpoint (has /v1 and is in a configured list) + if is_llama_server(endpoint): return False if "/v1" not in endpoint: return False base_endpoint = endpoint.replace('/v1', '') - if base_endpoint in cfg.endpoints: + if base_endpoint in get_config().endpoints: return False # It's Ollama's /v1 # Check for default Ollama port @@ -83,9 +102,9 @@ def is_ext_openai_endpoint(endpoint: str) -> bool: def is_openai_compatible(endpoint: str) -> bool: """ Return True if the endpoint speaks the OpenAI API (not native Ollama). - This includes external OpenAI endpoints AND llama-server endpoints. + This includes external OpenAI endpoints AND llama-server / llama-swap endpoints. """ - return "/v1" in endpoint or endpoint in get_config().llama_server_endpoints + return "/v1" in endpoint or is_llama_server(endpoint) def get_tracking_model(endpoint: str, model: str) -> str: @@ -102,8 +121,8 @@ def get_tracking_model(endpoint: str, model: str) -> str: if is_ext_openai_endpoint(endpoint): return model - # llama-server endpoints use normalized names in PS - if endpoint in get_config().llama_server_endpoints: + # llama-server / llama-swap endpoints use normalized names in PS + if is_llama_server(endpoint): return _normalize_llama_model_name(model) # Ollama endpoints: append ":latest" if no version suffix diff --git a/backends/probe.py b/backends/probe.py index 3ce089f..f59e65e 100644 --- a/backends/probe.py +++ b/backends/probe.py @@ -46,7 +46,7 @@ from backends.health import ( _format_connection_issue, _is_llama_model_loaded, ) -from backends.normalize import is_ext_openai_endpoint, is_openai_compatible +from backends.normalize import is_ext_openai_endpoint, is_openai_compatible, is_llama_server, is_llama_swap class fetch: @@ -61,10 +61,10 @@ class fetch: headers["Authorization"] = "Bearer " + api_key ep_base = endpoint.rstrip("/") - if endpoint in cfg.llama_server_endpoints and "/v1" not in endpoint: + if is_llama_server(endpoint) and "/v1" not in endpoint: endpoint_url = f"{ep_base}/v1/models" key = "data" - elif "/v1" in endpoint or endpoint in cfg.llama_server_endpoints: + elif "/v1" in endpoint or is_llama_server(endpoint): endpoint_url = f"{ep_base}/models" key = "data" else: @@ -194,6 +194,38 @@ class fetch: client: aiohttp.ClientSession = get_probe_session(endpoint) cfg = get_config() + # llama-swap: loaded/running workers are reported at /running (state == "ready"), + # NOT via a status field on /v1/models (which it omits). /running is a root route, + # so strip any /v1 suffix from the configured endpoint. + if is_llama_swap(endpoint): + base_url = endpoint.rstrip("/").removesuffix("/v1") + headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")} + api_key = cfg.api_keys.get(endpoint) + if api_key is not None: + headers["Authorization"] = "Bearer " + api_key + try: + async with client.get(f"{base_url}/running", headers=headers) as resp: + await _ensure_success(resp) + data = await resp.json() + + models = { + item.get("model") + for item in data.get("running", []) + if item.get("model") and item.get("state") == "ready" + } + + async with _loaded_models_cache_lock: + _loaded_models_cache[endpoint] = (models, time.time()) + async with _loaded_error_cache_lock: + _loaded_error_cache.pop(endpoint, None) + return models + except Exception as e: + message = _format_connection_issue(f"{base_url}/running", e) + print(f"[fetch.loaded_models] {message}") + async with _loaded_error_cache_lock: + _loaded_error_cache[endpoint] = time.time() + return set() + # Check if this is a llama-server endpoint if endpoint in cfg.llama_server_endpoints: # Query /v1/models for llama-server. Send the configured key as a diff --git a/config.py b/config.py index 143a2f9..03d8e94 100644 --- a/config.py +++ b/config.py @@ -23,6 +23,10 @@ class Config(BaseSettings): ) # List of llama-server endpoints (OpenAI-compatible with /v1/models status info) llama_server_endpoints: List[str] = Field(default_factory=list) + # List of llama-swap endpoints (OpenAI-compatible front for multiple llama-server + # workers). Same surface as llama_server_endpoints, but loaded models are read from + # /running (not /v1/models status) and unload uses POST /api/models/unload/:model_id. + llama_swap_endpoints: List[str] = Field(default_factory=list) # Max concurrent connections per endpoint‑model pair, see OLLAMA_NUM_PARALLEL max_concurrent_connections: int = 1 # Per-endpoint overrides: {endpoint_url: {max_concurrent_connections: N}} diff --git a/config.yaml b/config.yaml index 2107a3c..51ebb1b 100644 --- a/config.yaml +++ b/config.yaml @@ -6,7 +6,15 @@ endpoints: - https://api.openai.com/v1 llama_server_endpoints: - - http://192.168.0.50:8889/v1 + - http://192.168.0.51:8889/v1 + +# llama-swap endpoints (OpenAI-compatible front for multiple llama-server workers). +# Same surface as llama_server_endpoints, but the router reads loaded/running workers +# from /running (state == "ready") instead of a /v1/models status field, and unloads via +# POST /api/models/unload/:model_id. The router also exposes /upstream/:model_id/ +# to bypass llama-swap and reach a model's underlying llama-server worker directly. +llama_swap_endpoints: + - http://192.168.0.52:8890/v1 # Maximum concurrent connections *per endpoint‑model pair* (equals to OLLAMA_NUM_PARALLEL) # This is the global default; individual endpoints can override it via endpoint_config below. @@ -57,7 +65,8 @@ api_keys: "http://192.168.0.51:11434": "ollama" "http://192.168.0.52:11434": "ollama" "https://api.openai.com/v1": "${OPENAI_KEY}" - "http://192.168.0.50:8889/v1": "llama" + "http://192.168.0.51:8889/v1": "llama" + "http://192.168.0.52:8889/v1": "llama-swap" # ------------------------------------------------------------- # Semantic LLM Cache (optional — disabled by default) diff --git a/doc/configuration.md b/doc/configuration.md index 1addd66..e067207 100644 --- a/doc/configuration.md +++ b/doc/configuration.md @@ -78,6 +78,37 @@ endpoints: - OpenAI-compatible endpoints use `/v1` prefix - The router automatically detects endpoint type based on URL pattern +### `llama_server_endpoints` + +**Type**: `list[str]` (optional) + +**Default**: `[]` + +**Description**: List of [llama.cpp `llama-server`](https://github.com/ggml-org/llama.cpp) endpoints (OpenAI-compatible, configured with the `/v1` suffix). The router reads each backend's loaded models from `/v1/models` (entries with `status == "loaded"`) and unloads idle models via `POST /models/unload`. + +```yaml +llama_server_endpoints: + - http://192.168.0.50:8889/v1 +``` + +### `llama_swap_endpoints` + +**Type**: `list[str]` (optional) + +**Default**: `[]` + +**Description**: List of [llama-swap](https://github.com/mostlygeek/llama-swap) endpoints (OpenAI-compatible, configured with the `/v1` suffix). llama-swap fronts multiple `llama-server` workers behind one address. It is treated like `llama_server_endpoints` for routing, model discovery, and reranking, but differs in two ways the router handles automatically: + +- **Loaded-model detection** — llama-swap's `/v1/models` omits the per-model `status` field, so running workers are read from `GET /running` (entries with `state == "ready"`). +- **Model unload** — done via `POST /api/models/unload/:model_id` (path parameter), not the `llama-server` body form. + +The router also exposes a passthrough route, `GET|POST /upstream/:model_id/`, which forwards directly to a model's underlying `llama-server` worker (via llama-swap's `/upstream`), letting clients use `llama-server` features that llama-swap does not forward (e.g. token-array prompts). + +```yaml +llama_swap_endpoints: + - http://192.168.0.50:8890/v1 +``` + ### `max_concurrent_connections` **Type**: `int` diff --git a/router.py b/router.py index a2f9dd8..aca2d01 100644 --- a/router.py +++ b/router.py @@ -231,6 +231,7 @@ from backends.health import ( from backends.normalize import ( is_ext_openai_endpoint, is_openai_compatible, + llama_endpoints, get_tracking_model, ) @@ -310,6 +311,7 @@ async def startup_event() -> None: f"Loaded configuration from {config_path}:\n" f" endpoints={config.endpoints},\n" f" llama_server_endpoints={config.llama_server_endpoints},\n" + f" llama_swap_endpoints={config.llama_swap_endpoints},\n" f" max_concurrent_connections={config.max_concurrent_connections},\n" f" endpoint_config={config.endpoint_config},\n" f" priority_routing={config.priority_routing}" @@ -374,7 +376,7 @@ async def startup_event() -> None: app_state["httpx_clients"][ep] = httpx.AsyncClient(timeout=30.0) # Create per-endpoint Unix socket sessions for .sock endpoints - for ep in config.llama_server_endpoints: + for ep in llama_endpoints(config): if _is_unix_socket_endpoint(ep): sock_path = _get_socket_path(ep) sock_connector = aiohttp.UnixConnector(path=sock_path) @@ -391,7 +393,7 @@ async def startup_event() -> None: # client (/api/chat, /api/generate) and the OpenAI client (/v1/* routes), # so warm both; OpenAI-compatible endpoints only need the OpenAI client. _warm_endpoints = config.endpoints + [ - ep for ep in config.llama_server_endpoints if ep not in config.endpoints + ep for ep in llama_endpoints(config) if ep not in config.endpoints ] for ep in _warm_endpoints: try: diff --git a/routing.py b/routing.py index ecf6803..0a1cc7f 100644 --- a/routing.py +++ b/routing.py @@ -32,6 +32,8 @@ from backends.health import _is_fresh from backends.normalize import ( is_ext_openai_endpoint, is_openai_compatible, + is_llama_server, + llama_endpoints, get_tracking_model, ) from backends.probe import fetch @@ -93,8 +95,8 @@ async def choose_endpoint(model: str, reserve: bool = True, """ config = get_config() # 1️⃣ Gather advertised‑model sets for all endpoints concurrently - # Include both config.endpoints and config.llama_server_endpoints - llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints] + # Include config.endpoints plus any llama-server / llama-swap endpoints + llama_eps_extra = [ep for ep in llama_endpoints(config) if ep not in config.endpoints] all_endpoints = config.endpoints + llama_eps_extra tag_tasks = [fetch.available_models(ep) for ep in config.endpoints if not is_openai_compatible(ep)] @@ -114,7 +116,7 @@ async def choose_endpoint(model: str, reserve: bool = True, model_without_latest = model.split(":latest")[0] candidate_endpoints = [ ep for ep, models in zip(all_endpoints, advertised_sets) - if model_without_latest in models and (is_ext_openai_endpoint(ep) or ep in config.llama_server_endpoints) + if model_without_latest in models and (is_ext_openai_endpoint(ep) or is_llama_server(ep)) ] if not candidate_endpoints: # Only add :latest suffix if model doesn't already have a version suffix diff --git a/test/config_test.yaml b/test/config_test.yaml index 30f2fa3..ed96542 100644 --- a/test/config_test.yaml +++ b/test/config_test.yaml @@ -4,10 +4,14 @@ endpoints: llama_server_endpoints: - http://192.168.0.51:12434/v1 +llama_swap_endpoints: + - http://192.168.0.51:12435/v1 + max_concurrent_connections: 2 api_keys: "http://192.168.0.51:12434": "ollama" "http://192.168.0.51:12434/v1": "llama" + "http://192.168.0.51:12435/v1": "llama-swap" cache_enabled: false diff --git a/test/conftest.py b/test/conftest.py index c5142da..da7dacf 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -57,6 +57,7 @@ def mock_config(): cfg = MagicMock() cfg.endpoints = [TEST_OLLAMA] cfg.llama_server_endpoints = [TEST_LLAMA] + cfg.llama_swap_endpoints = [] cfg.api_keys = {TEST_OLLAMA: "ollama", TEST_LLAMA: "llama"} cfg.max_concurrent_connections = 2 cfg.router_api_key = None @@ -70,6 +71,7 @@ def mock_config_no_llama(): cfg = MagicMock() cfg.endpoints = [TEST_OLLAMA] cfg.llama_server_endpoints = [] + cfg.llama_swap_endpoints = [] cfg.api_keys = {TEST_OLLAMA: "ollama"} cfg.max_concurrent_connections = 2 cfg.router_api_key = None @@ -83,6 +85,7 @@ def mock_config_with_key(): cfg = MagicMock() cfg.endpoints = [TEST_OLLAMA] cfg.llama_server_endpoints = [] + cfg.llama_swap_endpoints = [] cfg.api_keys = {} cfg.max_concurrent_connections = 2 cfg.router_api_key = "test-secret-key" diff --git a/test/test_choose_endpoint.py b/test/test_choose_endpoint.py index be75f82..17650c4 100644 --- a/test/test_choose_endpoint.py +++ b/test/test_choose_endpoint.py @@ -12,10 +12,11 @@ EP3 = "http://ep3:11434" LLAMA_EP = "http://llama:8080/v1" -def _make_cfg(endpoints, llama_eps=None, max_conn=2, endpoint_config=None, priority_routing=False): +def _make_cfg(endpoints, llama_eps=None, swap_eps=None, max_conn=2, endpoint_config=None, priority_routing=False): cfg = MagicMock() cfg.endpoints = endpoints cfg.llama_server_endpoints = llama_eps or [] + cfg.llama_swap_endpoints = swap_eps or [] cfg.api_keys = {} cfg.max_concurrent_connections = max_conn cfg.endpoint_config = endpoint_config or {} @@ -46,6 +47,27 @@ class TestChooseEndpointBasic: assert ep == EP1 assert tracking == "llama3.2:latest" + async def test_llama_swap_endpoint_is_a_candidate(self): + swap_ep = "http://swap:8080/v1" + cfg = _make_cfg([EP1], swap_eps=[swap_ep]) + + async def available(ep, *_): + # Only the llama-swap backend advertises this model + return {"org/model:Q4_K_M"} if ep == swap_ep else set() + + async def loaded(ep): + return {"org/model:Q4_K_M"} if ep == swap_ep else set() + + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", side_effect=available), + patch.object(router.fetch, "loaded_models", side_effect=loaded), + ): + ep, tracking = await router.choose_endpoint("org/model:Q4_K_M") + assert ep == swap_ep + # llama-swap models are tracked under their normalized name + assert tracking == "model" + async def test_raises_when_no_endpoint_has_model(self): cfg = _make_cfg([EP1, EP2]) with ( diff --git a/test/test_fetch.py b/test/test_fetch.py index 76121e1..dae51e4 100644 --- a/test/test_fetch.py +++ b/test/test_fetch.py @@ -20,10 +20,11 @@ MOCK_OLLAMA_EP = "http://mock-ollama:11434" MOCK_LLAMA_EP = "http://mock-llama:8080/v1" -def _make_cfg(ollama_eps=None, llama_eps=None, api_keys=None): +def _make_cfg(ollama_eps=None, llama_eps=None, swap_eps=None, api_keys=None): cfg = MagicMock() cfg.endpoints = ollama_eps or [MOCK_OLLAMA_EP] cfg.llama_server_endpoints = llama_eps or [MOCK_LLAMA_EP] + cfg.llama_swap_endpoints = swap_eps or [] cfg.api_keys = api_keys or {} cfg.max_concurrent_connections = 2 cfg.router_api_key = None @@ -228,6 +229,30 @@ class TestFetchLoadedModels: models = await router.fetch.loaded_models(MOCK_LLAMA_EP) assert "always-on-model" in models + async def test_llama_swap_reads_running_state_ready(self): + # llama-swap omits the /v1/models status field, so loaded workers come + # from /running (a root route — the /v1 suffix must be stripped). + swap_ep = "http://mock-swap:8080/v1" + cfg = _make_cfg(llama_eps=[], swap_eps=[swap_ep]) + with patch.object(router, "config", cfg), mock_probe() as m: + m.add_get( + "http://mock-swap:8080/running", + payload={"running": [ + {"model": "org/ready-model:Q4_K_M", "state": "ready"}, + {"model": "org/starting-model:Q8_0", "state": "starting"}, + ]}, + ) + models = await router.fetch.loaded_models(swap_ep) + assert models == {"org/ready-model:Q4_K_M"} + + async def test_llama_swap_records_error_on_failure(self): + swap_ep = "http://mock-swap:8080/v1" + cfg = _make_cfg(llama_eps=[], swap_eps=[swap_ep]) + with patch.object(router, "config", cfg), mock_probe() as m: + m.add_get("http://mock-swap:8080/running", status=502, payload={}) + await router.fetch.loaded_models(swap_ep) + assert swap_ep in router._loaded_error_cache + async def test_returns_empty_on_error(self): cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) with patch.object(router, "config", cfg), mock_probe() as m: diff --git a/test/test_llama_swap.py b/test/test_llama_swap.py new file mode 100644 index 0000000..d0427bf --- /dev/null +++ b/test/test_llama_swap.py @@ -0,0 +1,109 @@ +"""Tests for llama-swap specific behavior: unload dispatch + /upstream resolution.""" +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +import router +import backends.control as control +import api.openai as openai_api + +SWAP_EP = "http://swap:8080/v1" +SERVER_EP = "http://server:8080/v1" + + +def _cfg(*, server=None, swap=None, api_keys=None): + cfg = MagicMock() + cfg.endpoints = [] + cfg.llama_server_endpoints = server or [] + cfg.llama_swap_endpoints = swap or [] + cfg.api_keys = api_keys or {} + return cfg + + +class _RecordingSession: + """Captures the most recent ``post`` call and returns a 200 response.""" + + def __init__(self, status=200): + self.calls = [] + self._status = status + + def post(self, url, **kwargs): + self.calls.append((url, kwargs)) + resp = MagicMock() + resp.status = self._status + + class _Ctx: + async def __aenter__(self_): + return resp + + async def __aexit__(self_, *exc): + return False + + return _Ctx() + + +class TestUnloadDispatch: + async def test_llama_swap_uses_path_param(self): + sess = _RecordingSession() + cfg = _cfg(swap=[SWAP_EP]) + with ( + patch.object(router, "config", cfg), + patch.object(control, "get_probe_session", lambda ep: sess), + ): + ok = await control.unload_model(SWAP_EP, "org/model:Q4_K_M") + assert ok is True + url, kwargs = sess.calls[0] + # /v1 stripped, model id is a path param, no JSON body + assert url == "http://swap:8080/api/models/unload/org/model:Q4_K_M" + assert kwargs.get("json") is None + + async def test_llama_server_uses_body(self): + sess = _RecordingSession() + cfg = _cfg(server=[SERVER_EP]) + with ( + patch.object(router, "config", cfg), + patch.object(control, "get_probe_session", lambda ep: sess), + ): + ok = await control.unload_model(SERVER_EP, "org/model:Q4_K_M") + assert ok is True + url, kwargs = sess.calls[0] + assert url == "http://server:8080/models/unload" + assert kwargs.get("json") == {"model": "org/model:Q4_K_M"} + + async def test_unload_failure_returns_false(self): + sess = _RecordingSession(status=500) + cfg = _cfg(swap=[SWAP_EP]) + with ( + patch.object(router, "config", cfg), + patch.object(control, "get_probe_session", lambda ep: sess), + ): + ok = await control.unload_model(SWAP_EP, "m") + assert ok is False + + +class TestUpstreamResolution: + async def test_resolves_endpoint_that_advertises_model(self): + cfg = _cfg(swap=[SWAP_EP]) + with ( + patch.object(openai_api, "get_config", lambda: cfg), + patch.object(openai_api.fetch, "available_models", + AsyncMock(return_value={"org/model:Q4_K_M"})), + ): + ep = await openai_api._resolve_llama_swap_endpoint("org/model:Q4_K_M") + assert ep == SWAP_EP + + async def test_returns_none_when_unserved(self): + cfg = _cfg(swap=[SWAP_EP]) + with ( + patch.object(openai_api, "get_config", lambda: cfg), + patch.object(openai_api.fetch, "available_models", + AsyncMock(return_value=set())), + ): + ep = await openai_api._resolve_llama_swap_endpoint("missing") + assert ep is None + + async def test_returns_none_without_swap_endpoints(self): + cfg = _cfg(swap=[]) + with patch.object(openai_api, "get_config", lambda: cfg): + ep = await openai_api._resolve_llama_swap_endpoint("any") + assert ep is None diff --git a/test/test_unit_helpers.py b/test/test_unit_helpers.py index d38eb37..def7082 100644 --- a/test/test_unit_helpers.py +++ b/test/test_unit_helpers.py @@ -277,3 +277,49 @@ class TestGetTrackingModel: with patch.object(router, "config", cfg): result = router.get_tracking_model(ep, "unsloth/model:Q8_0") assert result == "model" + + +class TestLlamaSwapClassification: + def _cfg(self, *, server=None, swap=None): + cfg = MagicMock() + cfg.endpoints = [] + cfg.llama_server_endpoints = server or [] + cfg.llama_swap_endpoints = swap or [] + return cfg + + def test_is_llama_swap_only_for_swap_list(self): + from backends.normalize import is_llama_swap + swap_ep = "http://host:8890/v1" + server_ep = "http://host:8889/v1" + cfg = self._cfg(server=[server_ep], swap=[swap_ep]) + with patch.object(router, "config", cfg): + assert is_llama_swap(swap_ep) is True + assert is_llama_swap(server_ep) is False + + def test_is_llama_server_covers_both(self): + from backends.normalize import is_llama_server + swap_ep = "http://host:8890/v1" + server_ep = "http://host:8889/v1" + cfg = self._cfg(server=[server_ep], swap=[swap_ep]) + with patch.object(router, "config", cfg): + assert is_llama_server(swap_ep) is True + assert is_llama_server(server_ep) is True + assert is_llama_server("http://host:11434") is False + + def test_swap_is_openai_compatible_not_ext(self): + swap_ep = "http://host:8890/v1" + cfg = self._cfg(swap=[swap_ep]) + with patch.object(router, "config", cfg): + assert router.is_openai_compatible(swap_ep) is True + assert router.is_ext_openai_endpoint(swap_ep) is False + + def test_swap_tracking_model_normalized(self): + swap_ep = "http://host:8890/v1" + cfg = self._cfg(swap=[swap_ep]) + with patch.object(router, "config", cfg): + assert router.get_tracking_model(swap_ep, "unsloth/model:Q8_0") == "model" + + def test_llama_endpoints_dedupes_and_orders(self): + from backends.normalize import llama_endpoints + cfg = self._cfg(server=["a", "b"], swap=["b", "c"]) + assert llama_endpoints(cfg) == ["a", "b", "c"] From cef71df3df5403a26bb7dd05ea731e57f918ad3b Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Mon, 15 Jun 2026 19:09:55 +0200 Subject: [PATCH 5/5] feat: add ctx-size for llama-swap models to dashboard --- api/ollama.py | 66 +++++++++++++++++++++++++++++++++++++++-- test/test_llama_swap.py | 22 ++++++++++++++ 2 files changed, 86 insertions(+), 2 deletions(-) diff --git a/api/ollama.py b/api/ollama.py index 0fc98aa..6ae6027 100644 --- a/api/ollama.py +++ b/api/ollama.py @@ -13,6 +13,7 @@ import asyncio import re import time from typing import Optional +from urllib.parse import quote import aiohttp import ollama @@ -976,6 +977,43 @@ async def _fetch_llama_swap_running(endpoint: str) -> list[dict]: ) +# Match the context size in a llama-swap worker's `cmd` string, e.g. +# "llama-server --port 5818 -hf ... --ctx-size 131072 ...". llama.cpp accepts +# both --ctx-size and the short -c alias. +_CTX_SIZE_CMD_RE = re.compile(r"(?:--ctx-size|-c)[=\s]+(\d+)") + + +def _ctx_size_from_cmd(cmd: str) -> int | None: + """Extract n_ctx from a llama-swap worker `cmd` string, or None if absent.""" + if not cmd: + return None + m = _CTX_SIZE_CMD_RE.search(cmd) + return int(m.group(1)) if m else None + + +async def _fetch_llama_swap_nctx(endpoint: str, model_id: str) -> int | None: + """Fallback when a worker's `cmd` lacks --ctx-size: ask the underlying + llama-server via llama-swap's /upstream//props route (plain /props?model= + is not routed by llama-swap and 404s). Returns n_ctx or None on any failure. + """ + config = get_config() + base_url = endpoint.rstrip("/").removesuffix("/v1") + props_url = f"{base_url}/upstream/{quote(model_id, safe='')}/props" + headers = None + api_key = config.api_keys.get(endpoint) + if api_key: + headers = {"Authorization": f"Bearer {api_key}"} + try: + client: aiohttp.ClientSession = get_probe_session(endpoint) + async with client.get(props_url, headers=headers, timeout=aiohttp.ClientTimeout(total=5)) as resp: + if resp.status == 200: + data = await resp.json() + return data.get("default_generation_settings", {}).get("n_ctx") + except Exception as e: + print(f"[ps_details] Failed to fetch props from {props_url}: {e}") + return None + + @router.get("/api/ps") async def ps_proxy(request: Request): """ @@ -1161,6 +1199,7 @@ async def ps_details_proxy(request: Request): swap_running = await asyncio.gather( *[_fetch_llama_swap_running(ep) for ep in config.llama_swap_endpoints] ) + swap_nctx_fallbacks: list[tuple[str, str, dict]] = [] for endpoint, runlist in zip(config.llama_swap_endpoints, swap_running): for item in runlist: if not isinstance(item, dict) or item.get("state") != "ready": @@ -1170,7 +1209,7 @@ async def ps_details_proxy(request: Request): continue normalized = _normalize_llama_model_name(raw_id) quant = _extract_llama_quant(raw_id) - models.append({ + swap_model = { "name": normalized, "id": normalized, "original_name": raw_id, @@ -1180,6 +1219,29 @@ async def ps_details_proxy(request: Request): "state": item.get("state"), "ttl": item.get("ttl"), "proxy": item.get("proxy"), - }) + } + # llama-swap omits n_ctx from /running, but the worker's launch + # command carries --ctx-size, so parse it from there (no extra + # request). Workers whose cmd lacks the flag fall back to an + # /upstream//props probe below. + n_ctx = _ctx_size_from_cmd(item.get("cmd", "")) + if n_ctx is not None: + swap_model["context_length"] = n_ctx + if 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(endpoint, normalized)] = n_ctx + else: + swap_nctx_fallbacks.append((endpoint, raw_id, swap_model)) + models.append(swap_model) + + # Resolve ctx for workers whose cmd lacked --ctx-size via /upstream props. + if swap_nctx_fallbacks: + fallback_results = await asyncio.gather( + *[_fetch_llama_swap_nctx(ep, rid) for ep, rid, _ in swap_nctx_fallbacks] + ) + for (ep, _rid, swap_model), n_ctx in zip(swap_nctx_fallbacks, fallback_results): + if n_ctx is not None: + swap_model["context_length"] = n_ctx + if 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(ep, swap_model["id"])] = n_ctx return JSONResponse(content={"models": models}, status_code=200) diff --git a/test/test_llama_swap.py b/test/test_llama_swap.py index d0427bf..74197d4 100644 --- a/test/test_llama_swap.py +++ b/test/test_llama_swap.py @@ -6,6 +6,7 @@ import pytest import router import backends.control as control import api.openai as openai_api +import api.ollama as ollama_api SWAP_EP = "http://swap:8080/v1" SERVER_EP = "http://server:8080/v1" @@ -107,3 +108,24 @@ class TestUpstreamResolution: with patch.object(openai_api, "get_config", lambda: cfg): ep = await openai_api._resolve_llama_swap_endpoint("any") assert ep is None + + +class TestCtxSizeFromCmd: + """ctx-size parsing from a /running worker's launch `cmd` string.""" + + def test_parses_long_flag(self): + cmd = ("llama-server --port 5818\n -hf unsloth/gpt-oss-20b-GGUF:F16\n" + " --ctx-size 131072\n --temp 1.0\n") + assert ollama_api._ctx_size_from_cmd(cmd) == 131072 + + def test_parses_short_flag(self): + assert ollama_api._ctx_size_from_cmd("llama-server -c 8192 --port 1") == 8192 + + def test_parses_equals_form(self): + assert ollama_api._ctx_size_from_cmd("llama-server --ctx-size=4096") == 4096 + + def test_returns_none_when_absent(self): + assert ollama_api._ctx_size_from_cmd("llama-server --port 5818") is None + + def test_returns_none_for_empty(self): + assert ollama_api._ctx_size_from_cmd("") is None