Merge pull request 'dev-1.0.x -> main' (#116) from dev-1.0.x into main
All checks were successful
Build and Publish Docker Image (Semantic Cache) / build (amd64, linux/amd64, docker-amd64) (push) Successful in 4m4s
Build and Publish Docker Image / build (amd64, linux/amd64, docker-amd64) (push) Successful in 1m47s
Build and Publish Docker Image (Semantic Cache) / build (arm64, linux/arm64, docker-arm64) (push) Successful in 17m8s
Build and Publish Docker Image (Semantic Cache) / merge (push) Successful in 38s
Build and Publish Docker Image / build (arm64, linux/arm64, docker-arm64) (push) Successful in 12m40s
Build and Publish Docker Image / merge (push) Successful in 39s
All checks were successful
Build and Publish Docker Image (Semantic Cache) / build (amd64, linux/amd64, docker-amd64) (push) Successful in 4m4s
Build and Publish Docker Image / build (amd64, linux/amd64, docker-amd64) (push) Successful in 1m47s
Build and Publish Docker Image (Semantic Cache) / build (arm64, linux/arm64, docker-arm64) (push) Successful in 17m8s
Build and Publish Docker Image (Semantic Cache) / merge (push) Successful in 38s
Build and Publish Docker Image / build (arm64, linux/arm64, docker-arm64) (push) Successful in 12m40s
Build and Publish Docker Image / merge (push) Successful in 39s
Reviewed-on: https://bitfreedom.net/code/code/nomyo-ai/nomyo-router/pulls/116
This commit is contained in:
commit
fc512fd0d8
22 changed files with 2406 additions and 155 deletions
37
README.md
37
README.md
|
|
@ -132,6 +132,41 @@ This way the Ollama backend servers are utilized more efficient than by simply u
|
|||
|
||||
NOMYO Router also supports OpenAI API compatible v1 backend servers.
|
||||
|
||||
## OpenAI Responses API
|
||||
|
||||
In addition to Chat Completions, NOMYO Router exposes the OpenAI **Responses API**:
|
||||
|
||||
```
|
||||
POST /v1/responses # create a response (stream or non-stream)
|
||||
GET /v1/responses/{id} # retrieve a stored response
|
||||
DELETE /v1/responses/{id} # delete a stored response
|
||||
POST /v1/responses/{id}/cancel # cancel a background response
|
||||
```
|
||||
|
||||
It works transparently across **all** backends. When the routed model lives on a native
|
||||
Responses backend (external OpenAI) the request is forwarded as-is; for Ollama and llama-server the
|
||||
router translates Responses ⇄ Chat Completions in both directions (request, response, and streaming
|
||||
typed SSE events), so clients get a consistent `/v1/responses` surface regardless of backend.
|
||||
|
||||
### Conversation state (`store` / `previous_response_id`)
|
||||
|
||||
The router **owns conversation state itself** (persisted in its SQLite DB) rather than delegating to
|
||||
the upstream provider, so `store` and `previous_response_id` behave identically on every backend.
|
||||
On a follow-up request the router rehydrates the prior turns from its DB and expands them into the
|
||||
conversation; outbound native calls always send `store=false`. Trade-off: this forgoes OpenAI's
|
||||
server-side reasoning-state reuse in exchange for uniform, backend-agnostic chaining.
|
||||
|
||||
### Background mode
|
||||
|
||||
`background:true` (which requires `store:true`) returns immediately with `{"status":"queued"}`; the
|
||||
request runs server-side and the client polls `GET /v1/responses/{id}` until the status reaches a
|
||||
terminal state (`completed` / `failed` / `cancelled`). `POST /v1/responses/{id}/cancel` aborts it.
|
||||
|
||||
Limitations: streaming reconnect-resume via `starting_after` is not yet implemented. In a
|
||||
multi-worker/replica deployment polling works via the shared DB, but `cancel` only reaches the
|
||||
running task in the worker that started it (other workers just mark the stored row cancelled). A
|
||||
background task interrupted by a server restart is reconciled to `failed` on the next startup.
|
||||
|
||||
## Semantic LLM Cache
|
||||
|
||||
NOMYO Router includes an optional semantic cache that serves repeated or semantically similar LLM requests from cache — no endpoint round-trip, no token cost, response in <10 ms.
|
||||
|
|
@ -172,7 +207,7 @@ Each request is keyed on `model + system_prompt` (exact) combined with a weighte
|
|||
|
||||
### Cached routes
|
||||
|
||||
`/api/chat` · `/api/generate` · `/v1/chat/completions` · `/v1/completions`
|
||||
`/api/chat` · `/api/generate` · `/v1/chat/completions` · `/v1/completions` · `/v1/responses`
|
||||
|
||||
### Cache management
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from state import (
|
|||
_affinity_lock,
|
||||
)
|
||||
from sse import subscribe, unsubscribe
|
||||
from backends.normalize import _normalize_llama_model_name
|
||||
from backends.normalize import _normalize_llama_model_name, is_llama_server, llama_endpoints
|
||||
from backends.probe import _endpoint_health
|
||||
|
||||
|
||||
|
|
@ -127,7 +127,6 @@ async def affinity_stats(request: Request):
|
|||
|
||||
now = time.monotonic()
|
||||
entries: list[dict] = []
|
||||
llama_eps = set(config.llama_server_endpoints)
|
||||
async with _affinity_lock:
|
||||
for fp, (ep, mdl, expires_at) in list(_affinity_map.items()):
|
||||
remaining = expires_at - now
|
||||
|
|
@ -136,7 +135,7 @@ async def affinity_stats(request: Request):
|
|||
continue
|
||||
# Mirror the normalisation used by /api/ps_details so the dashboard
|
||||
# can join affinity entries to PS rows by (endpoint, model).
|
||||
display_model = _normalize_llama_model_name(mdl) if ep in llama_eps else mdl
|
||||
display_model = _normalize_llama_model_name(mdl) if is_llama_server(ep) else mdl
|
||||
entries.append({
|
||||
"endpoint": ep,
|
||||
"model": display_model,
|
||||
|
|
@ -175,9 +174,12 @@ async def config_proxy(request: Request):
|
|||
|
||||
ollama_results = await asyncio.gather(*[check(ep) for ep in config.endpoints])
|
||||
llama_results = []
|
||||
if config.llama_server_endpoints:
|
||||
# llama-server and llama-swap render identically in the dashboard ("llama" rows),
|
||||
# so health-check both and merge them into one list.
|
||||
llama_eps = llama_endpoints(config)
|
||||
if llama_eps:
|
||||
llama_results = await asyncio.gather(
|
||||
*[check(ep) for ep in config.llama_server_endpoints]
|
||||
*[check(ep) for ep in llama_eps]
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
@ -227,7 +229,7 @@ async def health_proxy(request: Request):
|
|||
# purposes. Probing /api/version alone would miss the case where the
|
||||
# Ollama process is up but /api/ps is failing — see issue #83.
|
||||
all_endpoints = list(config.endpoints)
|
||||
llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
|
||||
llama_eps_extra = [ep for ep in llama_endpoints(config) if ep not in config.endpoints]
|
||||
all_endpoints += llama_eps_extra
|
||||
|
||||
probe_results = await asyncio.gather(
|
||||
|
|
|
|||
147
api/ollama.py
147
api/ollama.py
|
|
@ -13,6 +13,7 @@ import asyncio
|
|||
import re
|
||||
import time
|
||||
from typing import Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
import aiohttp
|
||||
import ollama
|
||||
|
|
@ -40,9 +41,12 @@ from backends.health import (
|
|||
from backends.normalize import (
|
||||
dedupe_on_keys,
|
||||
is_openai_compatible,
|
||||
is_llama_server,
|
||||
llama_endpoints,
|
||||
_normalize_llama_model_name,
|
||||
_extract_llama_quant,
|
||||
)
|
||||
from backends.control import unload_model
|
||||
from backends.probe import fetch
|
||||
from backends.sessions import _make_openai_client, get_ollama_client, get_probe_session
|
||||
from requests.chat import _make_moe_requests
|
||||
|
|
@ -372,7 +376,7 @@ async def chat_proxy(request: Request):
|
|||
if use_openai:
|
||||
start_ts = time.perf_counter()
|
||||
# Proactive trim: only for small-ctx models we've already seen run out of space
|
||||
_lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model
|
||||
_lookup_model = _normalize_llama_model_name(model) if is_llama_server(endpoint) else model
|
||||
_known_nctx = _endpoint_nctx.get((endpoint, _lookup_model))
|
||||
if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||
_pre_target = int((_known_nctx - _known_nctx // 4) / 1.2)
|
||||
|
|
@ -935,8 +939,8 @@ async def tags_proxy(request: Request):
|
|||
# 1. Query all endpoints for models
|
||||
tasks = [fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep]
|
||||
tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys[ep], skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" in ep]
|
||||
# Also query llama-server endpoints not already covered by config.endpoints
|
||||
llama_eps_for_tags = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
|
||||
# Also query llama-server / llama-swap endpoints not already covered by config.endpoints
|
||||
llama_eps_for_tags = [ep for ep in llama_endpoints(config) if ep not in config.endpoints]
|
||||
tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in llama_eps_for_tags]
|
||||
all_models = await asyncio.gather(*tasks)
|
||||
|
||||
|
|
@ -960,27 +964,79 @@ async def tags_proxy(request: Request):
|
|||
)
|
||||
|
||||
|
||||
async def _fetch_llama_swap_running(endpoint: str) -> list[dict]:
|
||||
"""Return the list of ready (`state == "ready"`) workers from a llama-swap
|
||||
endpoint's `/running` route. llama-swap omits the per-model `status` field on
|
||||
`/v1/models`, so running workers must be read here instead.
|
||||
"""
|
||||
config = get_config()
|
||||
base_url = endpoint.rstrip("/").removesuffix("/v1")
|
||||
return await fetch.endpoint_details(
|
||||
base_url, "/running", "running", config.api_keys.get(endpoint),
|
||||
skip_error_cache=True, timeout=8,
|
||||
)
|
||||
|
||||
|
||||
# Match the context size in a llama-swap worker's `cmd` string, e.g.
|
||||
# "llama-server --port 5818 -hf ... --ctx-size 131072 ...". llama.cpp accepts
|
||||
# both --ctx-size and the short -c alias.
|
||||
_CTX_SIZE_CMD_RE = re.compile(r"(?:--ctx-size|-c)[=\s]+(\d+)")
|
||||
|
||||
|
||||
def _ctx_size_from_cmd(cmd: str) -> int | None:
|
||||
"""Extract n_ctx from a llama-swap worker `cmd` string, or None if absent."""
|
||||
if not cmd:
|
||||
return None
|
||||
m = _CTX_SIZE_CMD_RE.search(cmd)
|
||||
return int(m.group(1)) if m else None
|
||||
|
||||
|
||||
async def _fetch_llama_swap_nctx(endpoint: str, model_id: str) -> int | None:
|
||||
"""Fallback when a worker's `cmd` lacks --ctx-size: ask the underlying
|
||||
llama-server via llama-swap's /upstream/<model>/props route (plain /props?model=
|
||||
is not routed by llama-swap and 404s). Returns n_ctx or None on any failure.
|
||||
"""
|
||||
config = get_config()
|
||||
base_url = endpoint.rstrip("/").removesuffix("/v1")
|
||||
props_url = f"{base_url}/upstream/{quote(model_id, safe='')}/props"
|
||||
headers = None
|
||||
api_key = config.api_keys.get(endpoint)
|
||||
if api_key:
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
try:
|
||||
client: aiohttp.ClientSession = get_probe_session(endpoint)
|
||||
async with client.get(props_url, headers=headers, timeout=aiohttp.ClientTimeout(total=5)) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return data.get("default_generation_settings", {}).get("n_ctx")
|
||||
except Exception as e:
|
||||
print(f"[ps_details] Failed to fetch props from {props_url}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/api/ps")
|
||||
async def ps_proxy(request: Request):
|
||||
"""
|
||||
Proxy a ps request to all Ollama and llama-server endpoints and reply a unique list of all running models.
|
||||
Proxy a ps request to all Ollama, llama-server and llama-swap endpoints and reply a unique list of all running models.
|
||||
|
||||
For Ollama endpoints: queries /api/ps
|
||||
For llama-server endpoints: queries /v1/models with status.value == "loaded"
|
||||
For llama-swap endpoints: queries /running (state == "ready")
|
||||
"""
|
||||
config = get_config()
|
||||
# 1. Query Ollama endpoints for running models via /api/ps
|
||||
ollama_tasks = [fetch.endpoint_details(ep, "/api/ps", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep]
|
||||
# 2. Query llama-server endpoints for loaded models via /v1/models
|
||||
# Also query endpoints from llama_server_endpoints that may not be in config.endpoints
|
||||
all_llama_endpoints = set(config.llama_server_endpoints) | set(ep for ep in config.endpoints if ep in config.llama_server_endpoints)
|
||||
llama_tasks = [
|
||||
fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
|
||||
for ep in all_llama_endpoints
|
||||
for ep in config.llama_server_endpoints
|
||||
]
|
||||
# 3. Query llama-swap endpoints for running workers via /running
|
||||
swap_tasks = [_fetch_llama_swap_running(ep) for ep in config.llama_swap_endpoints]
|
||||
|
||||
ollama_loaded = await asyncio.gather(*ollama_tasks) if ollama_tasks else []
|
||||
llama_loaded = await asyncio.gather(*llama_tasks) if llama_tasks else []
|
||||
swap_running = await asyncio.gather(*swap_tasks) if swap_tasks else []
|
||||
|
||||
models = {'models': []}
|
||||
# Add Ollama models (if any)
|
||||
|
|
@ -1003,6 +1059,21 @@ async def ps_proxy(request: Request):
|
|||
"status": item.get("status"),
|
||||
"details": {"quantization_level": quant} if quant else {}
|
||||
})
|
||||
# Add llama-swap running workers (already filtered on state == "ready")
|
||||
if swap_running:
|
||||
for runlist in swap_running:
|
||||
for item in runlist:
|
||||
if item.get("state") != "ready":
|
||||
continue
|
||||
raw_id = item.get("model", "")
|
||||
normalized = _normalize_llama_model_name(raw_id)
|
||||
quant = _extract_llama_quant(raw_id)
|
||||
models['models'].append({
|
||||
"name": normalized,
|
||||
"id": normalized,
|
||||
"digest": "",
|
||||
"details": {"quantization_level": quant} if quant else {}
|
||||
})
|
||||
|
||||
# 3. Return a JSONResponse with deduplicated currently deployed models
|
||||
# Deduplicate on 'name' rather than 'digest': llama-server models always
|
||||
|
|
@ -1101,16 +1172,7 @@ async def ps_details_proxy(request: Request):
|
|||
is_generation = "temperature" in dgs
|
||||
|
||||
if is_sleeping:
|
||||
unload_url = f"{base_url}/models/unload"
|
||||
try:
|
||||
async with client.post(
|
||||
unload_url,
|
||||
json={"model": model_id},
|
||||
headers=headers,
|
||||
) as unload_resp:
|
||||
print(f"[ps_details] Unloaded sleeping model {model_id} from {endpoint}: {unload_resp.status}")
|
||||
except Exception as ue:
|
||||
print(f"[ps_details] Failed to unload sleeping model {model_id} from {endpoint}: {ue}")
|
||||
await unload_model(endpoint, model_id)
|
||||
|
||||
return n_ctx, is_sleeping, is_generation
|
||||
except Exception as e:
|
||||
|
|
@ -1131,4 +1193,55 @@ async def ps_details_proxy(request: Request):
|
|||
if not is_sleeping:
|
||||
models.append(model_dict)
|
||||
|
||||
# Add llama-swap running workers (read from /running; no status/props/auto-unload —
|
||||
# llama-swap omits the status field on /v1/models and manages its own TTL eviction).
|
||||
if config.llama_swap_endpoints:
|
||||
swap_running = await asyncio.gather(
|
||||
*[_fetch_llama_swap_running(ep) for ep in config.llama_swap_endpoints]
|
||||
)
|
||||
swap_nctx_fallbacks: list[tuple[str, str, dict]] = []
|
||||
for endpoint, runlist in zip(config.llama_swap_endpoints, swap_running):
|
||||
for item in runlist:
|
||||
if not isinstance(item, dict) or item.get("state") != "ready":
|
||||
continue
|
||||
raw_id = item.get("model", "")
|
||||
if not raw_id:
|
||||
continue
|
||||
normalized = _normalize_llama_model_name(raw_id)
|
||||
quant = _extract_llama_quant(raw_id)
|
||||
swap_model = {
|
||||
"name": normalized,
|
||||
"id": normalized,
|
||||
"original_name": raw_id,
|
||||
"digest": "",
|
||||
"details": {"quantization_level": quant} if quant else {},
|
||||
"endpoint": endpoint,
|
||||
"state": item.get("state"),
|
||||
"ttl": item.get("ttl"),
|
||||
"proxy": item.get("proxy"),
|
||||
}
|
||||
# llama-swap omits n_ctx from /running, but the worker's launch
|
||||
# command carries --ctx-size, so parse it from there (no extra
|
||||
# request). Workers whose cmd lacks the flag fall back to an
|
||||
# /upstream/<model>/props probe below.
|
||||
n_ctx = _ctx_size_from_cmd(item.get("cmd", ""))
|
||||
if n_ctx is not None:
|
||||
swap_model["context_length"] = n_ctx
|
||||
if 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||
_endpoint_nctx[(endpoint, normalized)] = n_ctx
|
||||
else:
|
||||
swap_nctx_fallbacks.append((endpoint, raw_id, swap_model))
|
||||
models.append(swap_model)
|
||||
|
||||
# Resolve ctx for workers whose cmd lacked --ctx-size via /upstream props.
|
||||
if swap_nctx_fallbacks:
|
||||
fallback_results = await asyncio.gather(
|
||||
*[_fetch_llama_swap_nctx(ep, rid) for ep, rid, _ in swap_nctx_fallbacks]
|
||||
)
|
||||
for (ep, _rid, swap_model), n_ctx in zip(swap_nctx_fallbacks, fallback_results):
|
||||
if n_ctx is not None:
|
||||
swap_model["context_length"] = n_ctx
|
||||
if 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||
_endpoint_nctx[(ep, swap_model["id"])] = n_ctx
|
||||
|
||||
return JSONResponse(content={"models": models}, status_code=200)
|
||||
|
|
|
|||
284
api/openai.py
284
api/openai.py
|
|
@ -34,6 +34,8 @@ from backends.normalize import (
|
|||
ep2base,
|
||||
is_ext_openai_endpoint,
|
||||
is_openai_compatible,
|
||||
is_llama_server,
|
||||
llama_endpoints,
|
||||
_normalize_llama_model_name,
|
||||
)
|
||||
from backends.probe import fetch
|
||||
|
|
@ -46,6 +48,110 @@ from routing import choose_endpoint, decrement_usage
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
async def create_chat_with_retries(oclient, send_params, endpoint, model, tracking_model):
|
||||
"""Call ``chat.completions.create`` with the router's resilience retries.
|
||||
|
||||
Encapsulates the recovery ladder shared by the chat-completions handler and
|
||||
the translated ``/v1/responses`` path:
|
||||
|
||||
* ``does not support tools`` → retry without ``tools``
|
||||
* llama-server context exhaustion → sliding-window message trim, with a
|
||||
second retry that also strips ``tools``/``tool_choice``
|
||||
* backend connection failure → mark (endpoint, model) unhealthy so the next
|
||||
request reroutes, then re-raise
|
||||
* ``image input is not supported`` → strip images and retry
|
||||
|
||||
On unrecoverable failure the endpoint usage counter is decremented and the
|
||||
exception is re-raised. Returns the established async generator / response.
|
||||
"""
|
||||
config = get_config()
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**send_params)
|
||||
except Exception as e:
|
||||
_e_str = str(e)
|
||||
_is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str
|
||||
print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True)
|
||||
if "does not support tools" in _e_str:
|
||||
# Model doesn't support tools — retry without them
|
||||
print(f"[ochat] retry: no tools", flush=True)
|
||||
try:
|
||||
params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
|
||||
async_gen = await oclient.chat.completions.create(**params_without_tools)
|
||||
except Exception:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
elif _is_ctx_err:
|
||||
# Backend context limit hit — apply sliding-window trim (context-shift at message level)
|
||||
err_body = getattr(e, "body", {}) or {}
|
||||
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
|
||||
n_ctx_limit = err_detail.get("n_ctx", 0)
|
||||
actual_tokens = err_detail.get("n_prompt_tokens", 0)
|
||||
# Fallback: parse from string if body parsing yielded nothing (SDK may not parse llama-server errors)
|
||||
if not n_ctx_limit:
|
||||
import re as _re
|
||||
_m = _re.search(r"'n_ctx':\s*(\d+)", _e_str)
|
||||
if _m:
|
||||
n_ctx_limit = int(_m.group(1))
|
||||
_m = _re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
|
||||
if _m:
|
||||
actual_tokens = int(_m.group(1))
|
||||
print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True)
|
||||
if not n_ctx_limit:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
|
||||
_endpoint_nctx[(endpoint, model)] = n_ctx_limit
|
||||
|
||||
msgs_to_trim = send_params.get("messages", [])
|
||||
try:
|
||||
cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
|
||||
trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
|
||||
except Exception as _helper_exc:
|
||||
print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True)
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
dropped = len(msgs_to_trim) - len(trimmed_messages)
|
||||
print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True)
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages})
|
||||
print(f"[ctx-trim] retry-1 ok", flush=True)
|
||||
except Exception as e2:
|
||||
_e2_str = str(e2)
|
||||
if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str:
|
||||
# Still too large — tool definitions likely consuming too many tokens, strip them too
|
||||
print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True)
|
||||
params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")}
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages})
|
||||
print(f"[ctx-trim] retry-2 ok", flush=True)
|
||||
except Exception:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
else:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
elif _is_backend_connection_error(e):
|
||||
# Upstream connection failed (e.g. llama-server in router mode
|
||||
# whose delegated worker died). Mark (endpoint, model) so the
|
||||
# next request reroutes; the client will retry this one.
|
||||
print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
|
||||
await _mark_backend_unhealthy(endpoint, model, _e_str)
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
elif "image input is not supported" in _e_str:
|
||||
# Model doesn't support images — strip and retry
|
||||
print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages")
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))})
|
||||
except Exception:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
else:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
return async_gen
|
||||
|
||||
|
||||
@router.post("/v1/embeddings")
|
||||
async def openai_embedding_proxy(request: Request):
|
||||
"""
|
||||
|
|
@ -249,7 +355,7 @@ async def openai_chat_completions_proxy(request: Request):
|
|||
resolved_msgs = await _normalize_images_in_messages(params.get("messages", []))
|
||||
send_params = {**params, "messages": resolved_msgs}
|
||||
# Proactive trim: only for small-ctx models we've already seen run out of space
|
||||
_lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model
|
||||
_lookup_model = _normalize_llama_model_name(model) if is_llama_server(endpoint) else model
|
||||
_known_nctx = _endpoint_nctx.get((endpoint, _lookup_model))
|
||||
if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||
_pre_target = int(((_known_nctx - _known_nctx // 4)) / 1.2)
|
||||
|
|
@ -260,90 +366,7 @@ async def openai_chat_completions_proxy(request: Request):
|
|||
_dropped = len(_pre_msgs) - len(_pre_trimmed)
|
||||
print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
|
||||
send_params = {**send_params, "messages": _pre_trimmed}
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**send_params)
|
||||
except Exception as e:
|
||||
_e_str = str(e)
|
||||
_is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str
|
||||
print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True)
|
||||
if "does not support tools" in _e_str:
|
||||
# Model doesn't support tools — retry without them
|
||||
print(f"[ochat] retry: no tools", flush=True)
|
||||
try:
|
||||
params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
|
||||
async_gen = await oclient.chat.completions.create(**params_without_tools)
|
||||
except Exception:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
elif _is_ctx_err:
|
||||
# Backend context limit hit — apply sliding-window trim (context-shift at message level)
|
||||
err_body = getattr(e, "body", {}) or {}
|
||||
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
|
||||
n_ctx_limit = err_detail.get("n_ctx", 0)
|
||||
actual_tokens = err_detail.get("n_prompt_tokens", 0)
|
||||
# Fallback: parse from string if body parsing yielded nothing (SDK may not parse llama-server errors)
|
||||
if not n_ctx_limit:
|
||||
import re as _re
|
||||
_m = _re.search(r"'n_ctx':\s*(\d+)", _e_str)
|
||||
if _m:
|
||||
n_ctx_limit = int(_m.group(1))
|
||||
_m = _re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
|
||||
if _m:
|
||||
actual_tokens = int(_m.group(1))
|
||||
print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True)
|
||||
if not n_ctx_limit:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
|
||||
_endpoint_nctx[(endpoint, model)] = n_ctx_limit
|
||||
|
||||
msgs_to_trim = send_params.get("messages", [])
|
||||
try:
|
||||
cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
|
||||
trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
|
||||
except Exception as _helper_exc:
|
||||
print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True)
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
dropped = len(msgs_to_trim) - len(trimmed_messages)
|
||||
print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True)
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages})
|
||||
print(f"[ctx-trim] retry-1 ok", flush=True)
|
||||
except Exception as e2:
|
||||
_e2_str = str(e2)
|
||||
if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str:
|
||||
# Still too large — tool definitions likely consuming too many tokens, strip them too
|
||||
print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True)
|
||||
params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")}
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages})
|
||||
print(f"[ctx-trim] retry-2 ok", flush=True)
|
||||
except Exception:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
else:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
elif _is_backend_connection_error(e):
|
||||
# Upstream connection failed (e.g. llama-server in router mode
|
||||
# whose delegated worker died). Mark (endpoint, model) so the
|
||||
# next request reroutes; the client will retry this one.
|
||||
print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
|
||||
await _mark_backend_unhealthy(endpoint, model, _e_str)
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
elif "image input is not supported" in _e_str:
|
||||
# Model doesn't support images — strip and retry
|
||||
print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages")
|
||||
try:
|
||||
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))})
|
||||
except Exception:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
else:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise
|
||||
async_gen = await create_chat_with_retries(oclient, send_params, endpoint, model, tracking_model)
|
||||
|
||||
# 4. Async generator — only streams the already-established async_gen
|
||||
async def stream_ochat_response():
|
||||
|
|
@ -637,9 +660,9 @@ async def openai_models_proxy(request: Request):
|
|||
ollama_tasks = [fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep]
|
||||
# 2. Query external OpenAI endpoints (Groq, OpenAI, etc.) via /models
|
||||
ext_openai_tasks = [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in config.endpoints if is_ext_openai_endpoint(ep)]
|
||||
# 3. Query llama-server endpoints for loaded models via /v1/models
|
||||
# Also query endpoints from llama_server_endpoints that may not be in config.endpoints
|
||||
all_llama_endpoints = set(config.llama_server_endpoints) | set(ep for ep in config.endpoints if ep in config.llama_server_endpoints)
|
||||
# 3. Query llama-server / llama-swap endpoints for advertised models via /v1/models
|
||||
# Also query endpoints that may not be in config.endpoints
|
||||
all_llama_endpoints = llama_endpoints(config)
|
||||
llama_tasks = [
|
||||
fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
|
||||
for ep in all_llama_endpoints
|
||||
|
|
@ -762,10 +785,10 @@ async def rerank_proxy(request: Request):
|
|||
upstream_payload[optional_key] = payload[optional_key]
|
||||
|
||||
# Determine upstream URL:
|
||||
# llama-server exposes /v1/rerank (base already contains /v1 for llama_server_endpoints)
|
||||
# llama-server / llama-swap expose /v1/rerank (base already contains /v1)
|
||||
# External OpenAI endpoints expose /rerank under their /v1 base
|
||||
if endpoint in config.llama_server_endpoints:
|
||||
# llama-server: endpoint may or may not already contain /v1
|
||||
if is_llama_server(endpoint):
|
||||
# llama-server / llama-swap: endpoint may or may not already contain /v1
|
||||
if "/v1" in endpoint:
|
||||
rerank_url = f"{endpoint}/rerank"
|
||||
else:
|
||||
|
|
@ -802,3 +825,82 @@ async def rerank_proxy(request: Request):
|
|||
return JSONResponse(content=data)
|
||||
finally:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
|
||||
async def _resolve_llama_swap_endpoint(model_id: str) -> str | None:
|
||||
"""Pick the llama-swap endpoint that serves ``model_id``.
|
||||
|
||||
Prefers an endpoint that already has the worker running; falls back to any
|
||||
that advertises the model. Returns None if none do.
|
||||
"""
|
||||
config = get_config()
|
||||
swap_eps = config.llama_swap_endpoints
|
||||
if not swap_eps:
|
||||
return None
|
||||
|
||||
advertised = await asyncio.gather(
|
||||
*[fetch.available_models(ep, config.api_keys.get(ep)) for ep in swap_eps]
|
||||
)
|
||||
candidates = [ep for ep, models in zip(swap_eps, advertised) if model_id in models]
|
||||
if not candidates:
|
||||
return None
|
||||
if len(candidates) == 1:
|
||||
return candidates[0]
|
||||
|
||||
loaded = await asyncio.gather(*[fetch.loaded_models(ep) for ep in candidates])
|
||||
for ep, lm in zip(candidates, loaded):
|
||||
if model_id in lm:
|
||||
return ep
|
||||
return candidates[0]
|
||||
|
||||
|
||||
@router.api_route("/upstream/{model_id}/{path:path}", methods=["GET", "POST"])
|
||||
async def llama_swap_upstream(model_id: str, path: str, request: Request):
|
||||
"""Bypass llama-swap and reach a model's underlying llama-server worker directly
|
||||
via llama-swap's ``/upstream/:model_id`` route.
|
||||
|
||||
Lets clients use llama-server features that llama-swap itself does not forward
|
||||
(e.g. token-array prompts), while still letting the router pick the backend that
|
||||
actually hosts the model. ``/upstream`` is a root route, so the ``/v1`` suffix is
|
||||
stripped from the configured endpoint.
|
||||
"""
|
||||
config = get_config()
|
||||
endpoint = await _resolve_llama_swap_endpoint(model_id)
|
||||
if endpoint is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No configured llama-swap endpoint serves model '{model_id}'.",
|
||||
)
|
||||
|
||||
base_url = endpoint.rstrip("/").removesuffix("/v1")
|
||||
url = f"{base_url}/upstream/{model_id}/{path}"
|
||||
if request.url.query:
|
||||
url = f"{url}?{request.url.query}"
|
||||
|
||||
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
|
||||
content_type = request.headers.get("content-type")
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
api_key = config.api_keys.get(endpoint)
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = "Bearer " + api_key
|
||||
|
||||
body = await request.body()
|
||||
client: aiohttp.ClientSession = get_session(endpoint)
|
||||
try:
|
||||
resp = await client.request(request.method, url, data=body or None, headers=headers)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream request to {url} failed: {e}")
|
||||
|
||||
async def _iter():
|
||||
try:
|
||||
async for chunk in resp.content.iter_any():
|
||||
yield chunk
|
||||
finally:
|
||||
resp.release()
|
||||
|
||||
return StreamingResponse(
|
||||
_iter(),
|
||||
status_code=resp.status,
|
||||
media_type=resp.headers.get("Content-Type"),
|
||||
)
|
||||
|
|
|
|||
398
api/responses.py
Normal file
398
api/responses.py
Normal file
|
|
@ -0,0 +1,398 @@
|
|||
"""OpenAI **Responses API** routes (``/v1/responses`` and its retrieve / delete /
|
||||
cancel companions).
|
||||
|
||||
The router speaks Chat Completions to its backends, so this layer:
|
||||
|
||||
* **native** (external OpenAI): forwards via ``oclient.responses.create`` and
|
||||
streams the SDK's typed events straight back, rewriting the response ``id`` to
|
||||
a router-owned ``resp_`` id so chaining stays router-managed.
|
||||
* **translated** (Ollama / llama-server): converts the request to chat, reuses
|
||||
the resilient ``create_chat_with_retries`` ladder, and re-emits the result as
|
||||
Responses typed SSE events (``requests/responses.py``).
|
||||
|
||||
State (``store`` / ``previous_response_id``) and background-task status live in the
|
||||
router's SQLite DB (``db.py``); the router mints and owns every response id.
|
||||
"""
|
||||
import asyncio
|
||||
import secrets
|
||||
import time
|
||||
|
||||
import orjson
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from starlette.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from cache import get_llm_cache
|
||||
from config import get_config
|
||||
from db import get_db
|
||||
from fingerprint import _conversation_fingerprint
|
||||
from state import token_queue, default_headers
|
||||
from backends.normalize import is_ext_openai_endpoint
|
||||
from backends.sessions import _make_openai_client
|
||||
from routing import choose_endpoint, decrement_usage
|
||||
from api.openai import create_chat_with_retries
|
||||
from requests.responses import (
|
||||
ChatToResponsesStream,
|
||||
build_response_object,
|
||||
chat_message_to_output_items,
|
||||
messages_to_responses_input,
|
||||
responses_input_to_messages,
|
||||
responses_object_to_sse,
|
||||
tools_responses_to_chat,
|
||||
usage_chat_to_responses,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# In-memory handles for background tasks so /cancel can reach a running task in
|
||||
# this worker. Cross-worker cancel falls back to marking the DB row cancelled.
|
||||
_background_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# small helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
def _usage_tokens(usage):
|
||||
"""Return ``(prompt, completion)`` tokens from a chat- or responses-shaped usage."""
|
||||
if not usage:
|
||||
return 0, 0
|
||||
if "input_tokens" in usage:
|
||||
return usage.get("input_tokens", 0) or 0, usage.get("output_tokens", 0) or 0
|
||||
return usage.get("prompt_tokens", 0) or 0, usage.get("completion_tokens", 0) or 0
|
||||
|
||||
|
||||
def _text_format_to_response_format(text):
|
||||
"""Map Responses ``text.format`` → Chat Completions ``response_format`` (best effort)."""
|
||||
if not isinstance(text, dict):
|
||||
return None
|
||||
fmt = text.get("format")
|
||||
if not isinstance(fmt, dict):
|
||||
return None
|
||||
ftype = fmt.get("type")
|
||||
if ftype == "json_object":
|
||||
return {"type": "json_object"}
|
||||
if ftype == "json_schema":
|
||||
return {"type": "json_schema", "json_schema": {
|
||||
k: fmt[k] for k in ("name", "schema", "strict", "description") if k in fmt
|
||||
}}
|
||||
return None
|
||||
|
||||
|
||||
def _native_usage_from_response(data):
|
||||
return data.get("usage")
|
||||
|
||||
|
||||
async def _resolve_history_messages(previous_response_id):
|
||||
"""Rebuild prior-turn chat messages from the stored response chain."""
|
||||
if not previous_response_id:
|
||||
return []
|
||||
db = get_db()
|
||||
chain = await db.get_response_chain(previous_response_id)
|
||||
messages = []
|
||||
for turn in chain:
|
||||
# Each turn stored the chat messages that produced it + its output items.
|
||||
for m in turn.get("input_messages") or []:
|
||||
messages.append(m)
|
||||
for item in turn.get("output_items") or []:
|
||||
if item.get("type") == "message":
|
||||
text = "".join(
|
||||
p.get("text", "") for p in item.get("content") or []
|
||||
if p.get("type") == "output_text"
|
||||
)
|
||||
if text:
|
||||
messages.append({"role": "assistant", "content": text})
|
||||
elif item.get("type") == "function_call":
|
||||
messages.append({
|
||||
"role": "assistant", "content": None,
|
||||
"tool_calls": [{"id": item.get("call_id"), "type": "function",
|
||||
"function": {"name": item.get("name"),
|
||||
"arguments": item.get("arguments", "")}}],
|
||||
})
|
||||
return messages
|
||||
|
||||
|
||||
class _NativeStream:
|
||||
"""Re-emit an SDK Responses event stream, rewriting the response id and
|
||||
capturing the final output/usage for storage."""
|
||||
|
||||
def __init__(self, response_id):
|
||||
self.response_id = response_id
|
||||
self.output_items = []
|
||||
self.usage = None
|
||||
|
||||
async def events(self, sdk_gen):
|
||||
async for event in sdk_gen:
|
||||
data = event.model_dump() if hasattr(event, "model_dump") else event
|
||||
etype = data.get("type", "")
|
||||
resp = data.get("response")
|
||||
if isinstance(resp, dict) and resp.get("id"):
|
||||
resp["id"] = self.response_id
|
||||
if etype in ("response.completed", "response.incomplete", "response.failed") \
|
||||
and isinstance(resp, dict):
|
||||
self.output_items = resp.get("output", []) or []
|
||||
self.usage = resp.get("usage")
|
||||
yield f"event: {etype}\ndata: {orjson.dumps(data).decode('utf-8')}\n\n".encode("utf-8")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# backend execution (non-streaming, used by background + non-stream sync)
|
||||
# ---------------------------------------------------------------------------
|
||||
async def _run_to_completion(*, native, oclient, endpoint, model, tracking_model,
|
||||
send_params, native_params):
|
||||
"""Drive the backend to completion (no client streaming).
|
||||
|
||||
Returns ``(output_items, usage)`` where usage is responses-shaped. Caller is
|
||||
responsible for ``decrement_usage`` (translated failures self-decrement inside
|
||||
``create_chat_with_retries``)."""
|
||||
if native:
|
||||
resp_obj = await oclient.responses.create(stream=False, **native_params)
|
||||
data = resp_obj.model_dump()
|
||||
return data.get("output", []) or [], data.get("usage")
|
||||
async_gen = await create_chat_with_retries(oclient, {**send_params, "stream": False},
|
||||
endpoint, model, tracking_model)
|
||||
message = async_gen.choices[0].message.model_dump() if async_gen.choices else {}
|
||||
output_items = chat_message_to_output_items(message)
|
||||
usage = usage_chat_to_responses(
|
||||
async_gen.usage.model_dump() if async_gen.usage is not None else None
|
||||
)
|
||||
return output_items, usage
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /v1/responses
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post("/v1/responses")
|
||||
async def openai_responses_proxy(request: Request):
|
||||
config = get_config()
|
||||
try:
|
||||
payload = orjson.loads((await request.body()).decode("utf-8"))
|
||||
except orjson.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
model = payload.get("model")
|
||||
input_data = payload.get("input")
|
||||
instructions = payload.get("instructions")
|
||||
stream = bool(payload.get("stream"))
|
||||
store = payload.get("store", True)
|
||||
background = bool(payload.get("background"))
|
||||
previous_response_id = payload.get("previous_response_id")
|
||||
tools = payload.get("tools")
|
||||
metadata = payload.get("metadata") or {}
|
||||
_cache_enabled = payload.get("nomyo", {}).get("cache", False)
|
||||
|
||||
if not model:
|
||||
raise HTTPException(status_code=400, detail="Missing required field 'model'")
|
||||
if input_data is None:
|
||||
raise HTTPException(status_code=400, detail="Missing required field 'input'")
|
||||
if background and not store:
|
||||
raise HTTPException(status_code=400, detail="background mode requires store=true")
|
||||
|
||||
if ":latest" in model:
|
||||
model = model.split(":latest")[0]
|
||||
|
||||
# Resolve conversation: prior turns (from store) + this turn's input.
|
||||
history = await _resolve_history_messages(previous_response_id)
|
||||
messages = history + responses_input_to_messages(input_data, instructions)
|
||||
|
||||
response_id = f"resp_{secrets.token_hex(24)}"
|
||||
created_at = int(time.time())
|
||||
|
||||
# Cache lookup (foreground only) — before endpoint selection.
|
||||
_cache = get_llm_cache()
|
||||
if _cache is not None and _cache_enabled and not background:
|
||||
cached = await _cache.get_chat("openai_responses", model, messages)
|
||||
if cached is not None:
|
||||
resp_obj = orjson.loads(cached)
|
||||
resp_obj["id"] = response_id
|
||||
if stream:
|
||||
async def _served_cached():
|
||||
yield responses_object_to_sse(resp_obj)
|
||||
return StreamingResponse(_served_cached(), media_type="text/event-stream")
|
||||
return JSONResponse(content=resp_obj)
|
||||
|
||||
# Endpoint selection (reserves a slot — must be released exactly once).
|
||||
_affinity_key = _conversation_fingerprint(model, messages, None)
|
||||
endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
|
||||
oclient = _make_openai_client(endpoint, default_headers=default_headers,
|
||||
api_key=config.api_keys.get(endpoint, "no-key"))
|
||||
native = is_ext_openai_endpoint(endpoint)
|
||||
|
||||
# Build backend params for both shapes.
|
||||
send_params = {"messages": messages, "model": model}
|
||||
_opt = {
|
||||
"temperature": payload.get("temperature"),
|
||||
"top_p": payload.get("top_p"),
|
||||
"max_tokens": payload.get("max_output_tokens"),
|
||||
"tools": tools_responses_to_chat(tools),
|
||||
"tool_choice": payload.get("tool_choice"),
|
||||
"response_format": _text_format_to_response_format(payload.get("text")),
|
||||
}
|
||||
send_params.update({k: v for k, v in _opt.items() if v is not None})
|
||||
|
||||
native_instructions, native_input = messages_to_responses_input(messages)
|
||||
native_params = {"model": model, "input": native_input, "store": False}
|
||||
_nopt = {
|
||||
"instructions": native_instructions,
|
||||
"temperature": payload.get("temperature"),
|
||||
"top_p": payload.get("top_p"),
|
||||
"max_output_tokens": payload.get("max_output_tokens"),
|
||||
"tools": tools,
|
||||
"tool_choice": payload.get("tool_choice"),
|
||||
"text": payload.get("text"),
|
||||
"reasoning": payload.get("reasoning"),
|
||||
}
|
||||
native_params.update({k: v for k, v in _nopt.items() if v is not None})
|
||||
|
||||
async def _persist(status, output_items=None, usage=None, error=None, insert=False):
|
||||
if not store:
|
||||
return
|
||||
db = get_db()
|
||||
if insert:
|
||||
await db.store_response(
|
||||
response_id, previous_response_id=previous_response_id, model=model,
|
||||
status=status, created_at=created_at, input_messages=messages,
|
||||
output_items=output_items, usage=usage, instructions=instructions, error=error)
|
||||
else:
|
||||
await db.update_response_status(response_id, status, output_items=output_items,
|
||||
usage=usage, error=error)
|
||||
|
||||
async def _track(usage):
|
||||
prompt_tok, comp_tok = _usage_tokens(usage)
|
||||
if prompt_tok or comp_tok:
|
||||
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
|
||||
|
||||
async def _cache_store(output_items, usage):
|
||||
if _cache is None or not _cache_enabled or not output_items:
|
||||
return
|
||||
obj = build_response_object(response_id=response_id, model=model,
|
||||
output_items=output_items, usage=usage,
|
||||
created_at=created_at,
|
||||
previous_response_id=previous_response_id,
|
||||
instructions=instructions, metadata=metadata)
|
||||
try:
|
||||
await _cache.set_chat("openai_responses", model, messages, orjson.dumps(obj))
|
||||
except Exception as _ce:
|
||||
print(f"[cache] set_chat (openai_responses) failed: {_ce}")
|
||||
|
||||
# ---- background: run detached, return queued immediately --------------
|
||||
if background:
|
||||
await _persist("queued", insert=True)
|
||||
|
||||
async def _bg_run():
|
||||
try:
|
||||
await get_db().update_response_status(response_id, "in_progress")
|
||||
output_items, usage = await _run_to_completion(
|
||||
native=native, oclient=oclient, endpoint=endpoint, model=model,
|
||||
tracking_model=tracking_model, send_params=send_params,
|
||||
native_params=native_params)
|
||||
await _track(usage)
|
||||
await _persist("completed", output_items=output_items, usage=usage)
|
||||
await _cache_store(output_items, usage)
|
||||
except asyncio.CancelledError:
|
||||
await get_db().update_response_status(response_id, "cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
await get_db().update_response_status(
|
||||
response_id, "failed",
|
||||
error={"message": str(e)[:500], "type": type(e).__name__})
|
||||
finally:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
_background_tasks.pop(response_id, None)
|
||||
|
||||
task = asyncio.create_task(_bg_run())
|
||||
_background_tasks[response_id] = task
|
||||
queued = build_response_object(response_id=response_id, model=model, output_items=[],
|
||||
status="queued", created_at=created_at,
|
||||
previous_response_id=previous_response_id,
|
||||
instructions=instructions, metadata=metadata)
|
||||
return JSONResponse(content=queued, status_code=200)
|
||||
|
||||
# ---- streaming sync ----------------------------------------------------
|
||||
if stream:
|
||||
if native:
|
||||
source = await oclient.responses.create(stream=True, **native_params)
|
||||
translator = _NativeStream(response_id)
|
||||
else:
|
||||
source = await create_chat_with_retries(
|
||||
oclient, {**send_params, "stream": True,
|
||||
"stream_options": {"include_usage": True}},
|
||||
endpoint, model, tracking_model)
|
||||
translator = ChatToResponsesStream(
|
||||
response_id, model, created_at=created_at,
|
||||
previous_response_id=previous_response_id, instructions=instructions,
|
||||
metadata=metadata)
|
||||
|
||||
async def _stream():
|
||||
await _persist("in_progress", insert=True)
|
||||
try:
|
||||
async for sse in translator.events(source):
|
||||
yield sse
|
||||
await _track(translator.usage)
|
||||
await _persist("completed", output_items=translator.output_items,
|
||||
usage=translator.usage)
|
||||
await _cache_store(translator.output_items, translator.usage)
|
||||
finally:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
return StreamingResponse(_stream(), media_type="text/event-stream")
|
||||
|
||||
# ---- non-streaming sync ------------------------------------------------
|
||||
try:
|
||||
output_items, usage = await _run_to_completion(
|
||||
native=native, oclient=oclient, endpoint=endpoint, model=model,
|
||||
tracking_model=tracking_model, send_params=send_params,
|
||||
native_params=native_params)
|
||||
await _track(usage)
|
||||
await _persist("completed", output_items=output_items, usage=usage, insert=True)
|
||||
await _cache_store(output_items, usage)
|
||||
finally:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
resp_obj = build_response_object(
|
||||
response_id=response_id, model=model, output_items=output_items, usage=usage,
|
||||
created_at=created_at, previous_response_id=previous_response_id,
|
||||
instructions=instructions, metadata=metadata)
|
||||
return JSONResponse(content=resp_obj)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET / DELETE / cancel
|
||||
# ---------------------------------------------------------------------------
|
||||
def _stored_to_response_object(row):
|
||||
return build_response_object(
|
||||
response_id=row["response_id"], model=row.get("model"),
|
||||
output_items=row.get("output_items") or [], usage=row.get("usage"),
|
||||
status=row.get("status") or "completed", created_at=row.get("created_at"),
|
||||
previous_response_id=row.get("previous_response_id"),
|
||||
instructions=row.get("instructions"), error=row.get("error"))
|
||||
|
||||
|
||||
@router.get("/v1/responses/{response_id}")
|
||||
async def get_response(response_id: str):
|
||||
row = await get_db().get_response(response_id)
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
|
||||
return JSONResponse(content=_stored_to_response_object(row))
|
||||
|
||||
|
||||
@router.delete("/v1/responses/{response_id}")
|
||||
async def delete_response(response_id: str):
|
||||
deleted = await get_db().delete_response(response_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
|
||||
return JSONResponse(content={"id": response_id, "object": "response.deleted", "deleted": True})
|
||||
|
||||
|
||||
@router.post("/v1/responses/{response_id}/cancel")
|
||||
async def cancel_response(response_id: str):
|
||||
row = await get_db().get_response(response_id)
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
|
||||
# Cancel the running task if it lives in this worker; otherwise just mark the
|
||||
# DB row so a polling client sees a terminal state (cross-worker limitation).
|
||||
task = _background_tasks.get(response_id)
|
||||
if task is not None and not task.done():
|
||||
task.cancel()
|
||||
elif row.get("status") in ("queued", "in_progress"):
|
||||
await get_db().update_response_status(response_id, "cancelled")
|
||||
row = await get_db().get_response(response_id)
|
||||
return JSONResponse(content=_stored_to_response_object(row))
|
||||
50
backends/control.py
Normal file
50
backends/control.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
"""Backend control operations (model unload).
|
||||
|
||||
llama-server and llama-swap evict a resident model through different routes:
|
||||
* llama-server → ``POST {base}/models/unload`` with body ``{"model": id}``
|
||||
* llama-swap → ``POST {base}/api/models/unload/{id}`` (path parameter)
|
||||
|
||||
``unload_model`` dispatches on the configured backend type so callers don't
|
||||
have to know which one they are talking to. Both routes live at the endpoint
|
||||
root, so any ``/v1`` suffix is stripped first.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from config import get_config
|
||||
from state import default_headers
|
||||
from backends.sessions import get_probe_session
|
||||
from backends.normalize import is_llama_swap
|
||||
from backends.health import _format_connection_issue
|
||||
|
||||
|
||||
async def unload_model(endpoint: str, model_id: str) -> bool:
|
||||
"""Ask ``endpoint`` to unload ``model_id``. Returns True on a 2xx response.
|
||||
|
||||
``model_id`` must be the backend's native model identifier (the raw HF id
|
||||
for llama-server / llama-swap), not the router-normalized display name.
|
||||
"""
|
||||
cfg = get_config()
|
||||
base_url = endpoint.rstrip("/").removesuffix("/v1")
|
||||
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
|
||||
api_key: Optional[str] = cfg.api_keys.get(endpoint)
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = "Bearer " + api_key
|
||||
|
||||
if is_llama_swap(endpoint):
|
||||
url = f"{base_url}/api/models/unload/{model_id}"
|
||||
json_body = None
|
||||
else:
|
||||
url = f"{base_url}/models/unload"
|
||||
json_body = {"model": model_id}
|
||||
|
||||
client: aiohttp.ClientSession = get_probe_session(endpoint)
|
||||
try:
|
||||
async with client.post(url, json=json_body, headers=headers) as resp:
|
||||
ok = resp.status < 400
|
||||
print(f"[unload_model] {model_id} on {endpoint}: {resp.status}")
|
||||
return ok
|
||||
except Exception as e:
|
||||
print(f"[unload_model] {_format_connection_issue(url, e)}")
|
||||
return False
|
||||
|
|
@ -50,27 +50,46 @@ def dedupe_on_keys(dicts, key_fields):
|
|||
return out
|
||||
|
||||
|
||||
def is_llama_swap(endpoint: str) -> bool:
|
||||
"""True if the endpoint is a configured llama-swap front."""
|
||||
return endpoint in get_config().llama_swap_endpoints
|
||||
|
||||
|
||||
def is_llama_server(endpoint: str) -> bool:
|
||||
"""True for a llama.cpp llama-server OR a llama-swap front.
|
||||
|
||||
Both speak the same OpenAI-compatible surface, so the router treats them
|
||||
identically everywhere except loaded-model detection and model unload.
|
||||
"""
|
||||
cfg = get_config()
|
||||
return endpoint in cfg.llama_server_endpoints or endpoint in cfg.llama_swap_endpoints
|
||||
|
||||
|
||||
def llama_endpoints(cfg) -> list:
|
||||
"""Combined, de-duplicated llama-server + llama-swap endpoints (order preserved)."""
|
||||
return list(dict.fromkeys([*cfg.llama_server_endpoints, *cfg.llama_swap_endpoints]))
|
||||
|
||||
|
||||
def is_ext_openai_endpoint(endpoint: str) -> bool:
|
||||
"""
|
||||
Determine if an endpoint is an external OpenAI-compatible endpoint (not Ollama or llama-server).
|
||||
Determine if an endpoint is an external OpenAI-compatible endpoint (not Ollama, llama-server or llama-swap).
|
||||
|
||||
Returns True for:
|
||||
- External services like OpenAI.com, Groq, etc.
|
||||
|
||||
Returns False for:
|
||||
- Ollama endpoints (without /v1, or with /v1 but default port 11434)
|
||||
- llama-server endpoints (explicitly configured in llama_server_endpoints)
|
||||
- llama-server / llama-swap endpoints (explicitly configured)
|
||||
"""
|
||||
cfg = get_config()
|
||||
# Check if it's a llama-server endpoint (has /v1 and is in the configured list)
|
||||
if endpoint in cfg.llama_server_endpoints:
|
||||
# Check if it's a llama-server / llama-swap endpoint (has /v1 and is in a configured list)
|
||||
if is_llama_server(endpoint):
|
||||
return False
|
||||
|
||||
if "/v1" not in endpoint:
|
||||
return False
|
||||
|
||||
base_endpoint = endpoint.replace('/v1', '')
|
||||
if base_endpoint in cfg.endpoints:
|
||||
if base_endpoint in get_config().endpoints:
|
||||
return False # It's Ollama's /v1
|
||||
|
||||
# Check for default Ollama port
|
||||
|
|
@ -83,9 +102,9 @@ def is_ext_openai_endpoint(endpoint: str) -> bool:
|
|||
def is_openai_compatible(endpoint: str) -> bool:
|
||||
"""
|
||||
Return True if the endpoint speaks the OpenAI API (not native Ollama).
|
||||
This includes external OpenAI endpoints AND llama-server endpoints.
|
||||
This includes external OpenAI endpoints AND llama-server / llama-swap endpoints.
|
||||
"""
|
||||
return "/v1" in endpoint or endpoint in get_config().llama_server_endpoints
|
||||
return "/v1" in endpoint or is_llama_server(endpoint)
|
||||
|
||||
|
||||
def get_tracking_model(endpoint: str, model: str) -> str:
|
||||
|
|
@ -102,8 +121,8 @@ def get_tracking_model(endpoint: str, model: str) -> str:
|
|||
if is_ext_openai_endpoint(endpoint):
|
||||
return model
|
||||
|
||||
# llama-server endpoints use normalized names in PS
|
||||
if endpoint in get_config().llama_server_endpoints:
|
||||
# llama-server / llama-swap endpoints use normalized names in PS
|
||||
if is_llama_server(endpoint):
|
||||
return _normalize_llama_model_name(model)
|
||||
|
||||
# Ollama endpoints: append ":latest" if no version suffix
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ from backends.health import (
|
|||
_format_connection_issue,
|
||||
_is_llama_model_loaded,
|
||||
)
|
||||
from backends.normalize import is_ext_openai_endpoint, is_openai_compatible
|
||||
from backends.normalize import is_ext_openai_endpoint, is_openai_compatible, is_llama_server, is_llama_swap
|
||||
|
||||
|
||||
class fetch:
|
||||
|
|
@ -61,10 +61,10 @@ class fetch:
|
|||
headers["Authorization"] = "Bearer " + api_key
|
||||
|
||||
ep_base = endpoint.rstrip("/")
|
||||
if endpoint in cfg.llama_server_endpoints and "/v1" not in endpoint:
|
||||
if is_llama_server(endpoint) and "/v1" not in endpoint:
|
||||
endpoint_url = f"{ep_base}/v1/models"
|
||||
key = "data"
|
||||
elif "/v1" in endpoint or endpoint in cfg.llama_server_endpoints:
|
||||
elif "/v1" in endpoint or is_llama_server(endpoint):
|
||||
endpoint_url = f"{ep_base}/models"
|
||||
key = "data"
|
||||
else:
|
||||
|
|
@ -194,6 +194,38 @@ class fetch:
|
|||
client: aiohttp.ClientSession = get_probe_session(endpoint)
|
||||
cfg = get_config()
|
||||
|
||||
# llama-swap: loaded/running workers are reported at /running (state == "ready"),
|
||||
# NOT via a status field on /v1/models (which it omits). /running is a root route,
|
||||
# so strip any /v1 suffix from the configured endpoint.
|
||||
if is_llama_swap(endpoint):
|
||||
base_url = endpoint.rstrip("/").removesuffix("/v1")
|
||||
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
|
||||
api_key = cfg.api_keys.get(endpoint)
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = "Bearer " + api_key
|
||||
try:
|
||||
async with client.get(f"{base_url}/running", headers=headers) as resp:
|
||||
await _ensure_success(resp)
|
||||
data = await resp.json()
|
||||
|
||||
models = {
|
||||
item.get("model")
|
||||
for item in data.get("running", [])
|
||||
if item.get("model") and item.get("state") == "ready"
|
||||
}
|
||||
|
||||
async with _loaded_models_cache_lock:
|
||||
_loaded_models_cache[endpoint] = (models, time.time())
|
||||
async with _loaded_error_cache_lock:
|
||||
_loaded_error_cache.pop(endpoint, None)
|
||||
return models
|
||||
except Exception as e:
|
||||
message = _format_connection_issue(f"{base_url}/running", e)
|
||||
print(f"[fetch.loaded_models] {message}")
|
||||
async with _loaded_error_cache_lock:
|
||||
_loaded_error_cache[endpoint] = time.time()
|
||||
return set()
|
||||
|
||||
# Check if this is a llama-server endpoint
|
||||
if endpoint in cfg.llama_server_endpoints:
|
||||
# Query /v1/models for llama-server. Send the configured key as a
|
||||
|
|
|
|||
|
|
@ -23,6 +23,10 @@ class Config(BaseSettings):
|
|||
)
|
||||
# List of llama-server endpoints (OpenAI-compatible with /v1/models status info)
|
||||
llama_server_endpoints: List[str] = Field(default_factory=list)
|
||||
# List of llama-swap endpoints (OpenAI-compatible front for multiple llama-server
|
||||
# workers). Same surface as llama_server_endpoints, but loaded models are read from
|
||||
# /running (not /v1/models status) and unload uses POST /api/models/unload/:model_id.
|
||||
llama_swap_endpoints: List[str] = Field(default_factory=list)
|
||||
# Max concurrent connections per endpoint‑model pair, see OLLAMA_NUM_PARALLEL
|
||||
max_concurrent_connections: int = 1
|
||||
# Per-endpoint overrides: {endpoint_url: {max_concurrent_connections: N}}
|
||||
|
|
|
|||
13
config.yaml
13
config.yaml
|
|
@ -6,7 +6,15 @@ endpoints:
|
|||
- https://api.openai.com/v1
|
||||
|
||||
llama_server_endpoints:
|
||||
- http://192.168.0.50:8889/v1
|
||||
- http://192.168.0.51:8889/v1
|
||||
|
||||
# llama-swap endpoints (OpenAI-compatible front for multiple llama-server workers).
|
||||
# Same surface as llama_server_endpoints, but the router reads loaded/running workers
|
||||
# from /running (state == "ready") instead of a /v1/models status field, and unloads via
|
||||
# POST /api/models/unload/:model_id. The router also exposes /upstream/:model_id/<path>
|
||||
# to bypass llama-swap and reach a model's underlying llama-server worker directly.
|
||||
llama_swap_endpoints:
|
||||
- http://192.168.0.52:8890/v1
|
||||
|
||||
# Maximum concurrent connections *per endpoint‑model pair* (equals to OLLAMA_NUM_PARALLEL)
|
||||
# This is the global default; individual endpoints can override it via endpoint_config below.
|
||||
|
|
@ -57,7 +65,8 @@ api_keys:
|
|||
"http://192.168.0.51:11434": "ollama"
|
||||
"http://192.168.0.52:11434": "ollama"
|
||||
"https://api.openai.com/v1": "${OPENAI_KEY}"
|
||||
"http://192.168.0.50:8889/v1": "llama"
|
||||
"http://192.168.0.51:8889/v1": "llama"
|
||||
"http://192.168.0.52:8889/v1": "llama-swap"
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# Semantic LLM Cache (optional — disabled by default)
|
||||
|
|
|
|||
175
db.py
175
db.py
|
|
@ -1,4 +1,4 @@
|
|||
import aiosqlite, asyncio
|
||||
import aiosqlite, asyncio, orjson
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
|
@ -75,6 +75,24 @@ class TokenDatabase:
|
|||
''')
|
||||
await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_timestamp ON token_time_series(timestamp)')
|
||||
await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_model_ts ON token_time_series(model, timestamp)')
|
||||
# Responses API state — the router owns conversation state for the
|
||||
# /v1/responses family (store / previous_response_id) and tracks
|
||||
# background-task status here so polling survives across workers.
|
||||
await db.execute('''
|
||||
CREATE TABLE IF NOT EXISTS stored_responses (
|
||||
response_id TEXT PRIMARY KEY,
|
||||
previous_response_id TEXT,
|
||||
model TEXT,
|
||||
status TEXT,
|
||||
created_at INTEGER,
|
||||
input_messages TEXT,
|
||||
output_items TEXT,
|
||||
usage TEXT,
|
||||
instructions TEXT,
|
||||
error TEXT
|
||||
)
|
||||
''')
|
||||
await db.execute('CREATE INDEX IF NOT EXISTS idx_stored_responses_prev ON stored_responses(previous_response_id)')
|
||||
await db.commit()
|
||||
|
||||
async def update_token_counts(self, endpoint: str, model: str, input_tokens: int, output_tokens: int):
|
||||
|
|
@ -319,3 +337,158 @@ class TokenDatabase:
|
|||
await db.commit()
|
||||
|
||||
return aggregated_count
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Responses API state (store / previous_response_id / background)
|
||||
# -----------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _row_to_response(row) -> dict:
|
||||
"""Map a stored_responses row to a plain dict, decoding JSON columns."""
|
||||
def _loads(val):
|
||||
if val is None:
|
||||
return None
|
||||
try:
|
||||
return orjson.loads(val)
|
||||
except (orjson.JSONDecodeError, TypeError):
|
||||
return None
|
||||
return {
|
||||
'response_id': row[0],
|
||||
'previous_response_id': row[1],
|
||||
'model': row[2],
|
||||
'status': row[3],
|
||||
'created_at': row[4],
|
||||
'input_messages': _loads(row[5]),
|
||||
'output_items': _loads(row[6]),
|
||||
'usage': _loads(row[7]),
|
||||
'instructions': row[8],
|
||||
'error': _loads(row[9]),
|
||||
}
|
||||
|
||||
async def store_response(
|
||||
self,
|
||||
response_id: str,
|
||||
*,
|
||||
previous_response_id: Optional[str],
|
||||
model: str,
|
||||
status: str,
|
||||
created_at: int,
|
||||
input_messages: list,
|
||||
output_items: Optional[list] = None,
|
||||
usage: Optional[dict] = None,
|
||||
instructions: Optional[str] = None,
|
||||
error: Optional[dict] = None,
|
||||
):
|
||||
"""Insert or replace a stored Responses-API response row."""
|
||||
db = await self._get_connection()
|
||||
async with self._operation_lock:
|
||||
await db.execute('''
|
||||
INSERT INTO stored_responses
|
||||
(response_id, previous_response_id, model, status, created_at,
|
||||
input_messages, output_items, usage, instructions, error)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (response_id) DO UPDATE SET
|
||||
previous_response_id = excluded.previous_response_id,
|
||||
model = excluded.model,
|
||||
status = excluded.status,
|
||||
created_at = excluded.created_at,
|
||||
input_messages = excluded.input_messages,
|
||||
output_items = excluded.output_items,
|
||||
usage = excluded.usage,
|
||||
instructions = excluded.instructions,
|
||||
error = excluded.error
|
||||
''', (
|
||||
response_id, previous_response_id, model, status, created_at,
|
||||
orjson.dumps(input_messages).decode("utf-8"),
|
||||
orjson.dumps(output_items).decode("utf-8") if output_items is not None else None,
|
||||
orjson.dumps(usage).decode("utf-8") if usage is not None else None,
|
||||
instructions,
|
||||
orjson.dumps(error).decode("utf-8") if error is not None else None,
|
||||
))
|
||||
await db.commit()
|
||||
|
||||
async def update_response_status(
|
||||
self,
|
||||
response_id: str,
|
||||
status: str,
|
||||
*,
|
||||
output_items: Optional[list] = None,
|
||||
usage: Optional[dict] = None,
|
||||
error: Optional[dict] = None,
|
||||
):
|
||||
"""Update the status (and optionally output/usage/error) of a stored response."""
|
||||
db = await self._get_connection()
|
||||
async with self._operation_lock:
|
||||
await db.execute('''
|
||||
UPDATE stored_responses
|
||||
SET status = ?,
|
||||
output_items = COALESCE(?, output_items),
|
||||
usage = COALESCE(?, usage),
|
||||
error = COALESCE(?, error)
|
||||
WHERE response_id = ?
|
||||
''', (
|
||||
status,
|
||||
orjson.dumps(output_items).decode("utf-8") if output_items is not None else None,
|
||||
orjson.dumps(usage).decode("utf-8") if usage is not None else None,
|
||||
orjson.dumps(error).decode("utf-8") if error is not None else None,
|
||||
response_id,
|
||||
))
|
||||
await db.commit()
|
||||
|
||||
async def get_response(self, response_id: str) -> Optional[dict]:
|
||||
"""Return a stored response as a dict, or None if not found."""
|
||||
db = await self._get_connection()
|
||||
async with self._operation_lock:
|
||||
async with db.execute('''
|
||||
SELECT response_id, previous_response_id, model, status, created_at,
|
||||
input_messages, output_items, usage, instructions, error
|
||||
FROM stored_responses WHERE response_id = ?
|
||||
''', (response_id,)) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
return self._row_to_response(row) if row is not None else None
|
||||
|
||||
async def delete_response(self, response_id: str) -> bool:
|
||||
"""Delete a stored response. Returns True if a row was removed."""
|
||||
db = await self._get_connection()
|
||||
async with self._operation_lock:
|
||||
cursor = await db.execute(
|
||||
'DELETE FROM stored_responses WHERE response_id = ?', (response_id,)
|
||||
)
|
||||
await db.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
async def get_response_chain(self, response_id: str, max_turns: int = 50) -> list:
|
||||
"""Walk previous_response_id back to the root, returned oldest-first.
|
||||
|
||||
Bounded to ``max_turns`` so a pathological chain cannot stall a request.
|
||||
Missing links terminate the walk gracefully.
|
||||
"""
|
||||
chain: list = []
|
||||
seen: set = set()
|
||||
current = response_id
|
||||
while current and current not in seen and len(chain) < max_turns:
|
||||
seen.add(current)
|
||||
resp = await self.get_response(current)
|
||||
if resp is None:
|
||||
break
|
||||
chain.append(resp)
|
||||
current = resp.get('previous_response_id')
|
||||
chain.reverse()
|
||||
return chain
|
||||
|
||||
async def fail_orphaned_responses(self) -> int:
|
||||
"""Mark non-terminal responses as failed (called on startup).
|
||||
|
||||
A background task lives in a worker's event loop; a process restart loses
|
||||
it while the DB row stays ``queued``/``in_progress`` forever. Reconcile
|
||||
those to ``failed`` so polling clients get a terminal state.
|
||||
"""
|
||||
db = await self._get_connection()
|
||||
async with self._operation_lock:
|
||||
cursor = await db.execute('''
|
||||
UPDATE stored_responses
|
||||
SET status = 'failed',
|
||||
error = ?
|
||||
WHERE status IN ('queued', 'in_progress')
|
||||
''', (orjson.dumps({"message": "Response interrupted by server restart", "type": "server_error"}).decode("utf-8"),))
|
||||
await db.commit()
|
||||
return cursor.rowcount
|
||||
|
|
|
|||
|
|
@ -78,6 +78,37 @@ endpoints:
|
|||
- OpenAI-compatible endpoints use `/v1` prefix
|
||||
- The router automatically detects endpoint type based on URL pattern
|
||||
|
||||
### `llama_server_endpoints`
|
||||
|
||||
**Type**: `list[str]` (optional)
|
||||
|
||||
**Default**: `[]`
|
||||
|
||||
**Description**: List of [llama.cpp `llama-server`](https://github.com/ggml-org/llama.cpp) endpoints (OpenAI-compatible, configured with the `/v1` suffix). The router reads each backend's loaded models from `/v1/models` (entries with `status == "loaded"`) and unloads idle models via `POST /models/unload`.
|
||||
|
||||
```yaml
|
||||
llama_server_endpoints:
|
||||
- http://192.168.0.50:8889/v1
|
||||
```
|
||||
|
||||
### `llama_swap_endpoints`
|
||||
|
||||
**Type**: `list[str]` (optional)
|
||||
|
||||
**Default**: `[]`
|
||||
|
||||
**Description**: List of [llama-swap](https://github.com/mostlygeek/llama-swap) endpoints (OpenAI-compatible, configured with the `/v1` suffix). llama-swap fronts multiple `llama-server` workers behind one address. It is treated like `llama_server_endpoints` for routing, model discovery, and reranking, but differs in two ways the router handles automatically:
|
||||
|
||||
- **Loaded-model detection** — llama-swap's `/v1/models` omits the per-model `status` field, so running workers are read from `GET /running` (entries with `state == "ready"`).
|
||||
- **Model unload** — done via `POST /api/models/unload/:model_id` (path parameter), not the `llama-server` body form.
|
||||
|
||||
The router also exposes a passthrough route, `GET|POST /upstream/:model_id/<path>`, which forwards directly to a model's underlying `llama-server` worker (via llama-swap's `/upstream`), letting clients use `llama-server` features that llama-swap does not forward (e.g. token-array prompts).
|
||||
|
||||
```yaml
|
||||
llama_swap_endpoints:
|
||||
- http://192.168.0.50:8890/v1
|
||||
```
|
||||
|
||||
### `max_concurrent_connections`
|
||||
|
||||
**Type**: `int`
|
||||
|
|
|
|||
492
requests/responses.py
Normal file
492
requests/responses.py
Normal file
|
|
@ -0,0 +1,492 @@
|
|||
"""Translation between the OpenAI **Responses API** and **Chat Completions**.
|
||||
|
||||
The router speaks Chat Completions to every backend (Ollama, llama-server,
|
||||
external OpenAI). To expose ``/v1/responses`` transparently on top of that, this
|
||||
module converts in both directions:
|
||||
|
||||
* request: Responses ``input`` / ``instructions`` / ``tools`` → chat ``messages`` / ``tools``
|
||||
* response: chat ``choices[0].message`` → Responses ``output`` items
|
||||
* stream: chat completion deltas → Responses typed SSE events
|
||||
|
||||
Pure functions / a stream-translator class — no I/O, mirroring the style of
|
||||
``requests/messages.py``. The native passthrough path (external OpenAI) does not
|
||||
use this module; it forwards the SDK's Responses objects directly.
|
||||
"""
|
||||
import secrets
|
||||
import time
|
||||
|
||||
import orjson
|
||||
|
||||
from requests.messages import _accumulate_openai_tc_delta
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request direction: Responses → Chat Completions
|
||||
# ---------------------------------------------------------------------------
|
||||
def _responses_content_to_chat(content):
|
||||
"""Convert a Responses message ``content`` into Chat Completions content.
|
||||
|
||||
Collapses a single text part to a plain string (what most backends expect);
|
||||
keeps a multimodal list otherwise.
|
||||
"""
|
||||
if content is None or isinstance(content, str):
|
||||
return content
|
||||
if not isinstance(content, list):
|
||||
return str(content)
|
||||
parts = []
|
||||
for p in content:
|
||||
if not isinstance(p, dict):
|
||||
parts.append({"type": "text", "text": str(p)})
|
||||
continue
|
||||
ptype = p.get("type")
|
||||
if ptype in ("input_text", "output_text", "text"):
|
||||
parts.append({"type": "text", "text": p.get("text", "")})
|
||||
elif ptype in ("input_image", "image_url"):
|
||||
url = p.get("image_url")
|
||||
if isinstance(url, dict):
|
||||
url = url.get("url")
|
||||
if url:
|
||||
parts.append({"type": "image_url", "image_url": {"url": url}})
|
||||
# input_file / refusal / reasoning parts have no chat equivalent → skip
|
||||
if len(parts) == 1 and parts[0].get("type") == "text":
|
||||
return parts[0]["text"]
|
||||
return parts
|
||||
|
||||
|
||||
def _input_item_to_message(item):
|
||||
"""Convert a single Responses ``input`` item to a chat message (or None)."""
|
||||
if isinstance(item, str):
|
||||
return {"role": "user", "content": item}
|
||||
if not isinstance(item, dict):
|
||||
return None
|
||||
|
||||
itype = item.get("type")
|
||||
|
||||
if itype == "function_call":
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": item.get("call_id") or item.get("id"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": item.get("name"),
|
||||
"arguments": item.get("arguments", ""),
|
||||
},
|
||||
}],
|
||||
}
|
||||
|
||||
if itype == "function_call_output":
|
||||
output = item.get("output", "")
|
||||
if not isinstance(output, str):
|
||||
output = orjson.dumps(output).decode("utf-8")
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": item.get("call_id") or item.get("id"),
|
||||
"content": output,
|
||||
}
|
||||
|
||||
if itype in ("reasoning",):
|
||||
# No Chat Completions equivalent — drop.
|
||||
return None
|
||||
|
||||
# "message" item or a bare {role, content} chat-style item
|
||||
role = item.get("role")
|
||||
if role is None:
|
||||
return None
|
||||
return {"role": role, "content": _responses_content_to_chat(item.get("content"))}
|
||||
|
||||
|
||||
def responses_input_to_messages(input_data, instructions=None):
|
||||
"""Build a Chat Completions ``messages`` list from Responses ``input``.
|
||||
|
||||
``instructions`` becomes a leading system message; a string ``input`` becomes
|
||||
a single user message; a list ``input`` is mapped item-by-item.
|
||||
"""
|
||||
messages = []
|
||||
if instructions:
|
||||
messages.append({"role": "system", "content": instructions})
|
||||
if input_data is None:
|
||||
return messages
|
||||
if isinstance(input_data, str):
|
||||
messages.append({"role": "user", "content": input_data})
|
||||
return messages
|
||||
if isinstance(input_data, list):
|
||||
for item in input_data:
|
||||
msg = _input_item_to_message(item)
|
||||
if msg is not None:
|
||||
messages.append(msg)
|
||||
return messages
|
||||
|
||||
|
||||
def _chat_content_to_responses_parts(content, assistant=False):
|
||||
"""Convert chat message content → Responses content parts."""
|
||||
text_type = "output_text" if assistant else "input_text"
|
||||
if content is None:
|
||||
return []
|
||||
if isinstance(content, str):
|
||||
return [{"type": text_type, "text": content}]
|
||||
parts = []
|
||||
for p in content if isinstance(content, list) else []:
|
||||
if not isinstance(p, dict):
|
||||
parts.append({"type": text_type, "text": str(p)})
|
||||
elif p.get("type") == "text":
|
||||
parts.append({"type": text_type, "text": p.get("text", "")})
|
||||
elif p.get("type") == "image_url":
|
||||
url = (p.get("image_url") or {}).get("url")
|
||||
if url:
|
||||
parts.append({"type": "input_image", "image_url": url})
|
||||
return parts
|
||||
|
||||
|
||||
def messages_to_responses_input(messages):
|
||||
"""Convert chat messages → ``(instructions, Responses input items)``.
|
||||
|
||||
Used for the native passthrough path: history that the router has resolved in
|
||||
chat-message space is re-expressed as Responses ``input``. Leading/standalone
|
||||
system messages are merged into ``instructions``.
|
||||
"""
|
||||
instructions_parts = []
|
||||
items = []
|
||||
for m in messages:
|
||||
role = m.get("role")
|
||||
if role == "system":
|
||||
c = m.get("content")
|
||||
instructions_parts.append(c if isinstance(c, str) else orjson.dumps(c).decode("utf-8"))
|
||||
continue
|
||||
if role == "tool":
|
||||
out = m.get("content")
|
||||
if not isinstance(out, str):
|
||||
out = orjson.dumps(out).decode("utf-8")
|
||||
items.append({"type": "function_call_output",
|
||||
"call_id": m.get("tool_call_id"), "output": out})
|
||||
continue
|
||||
if role == "assistant" and m.get("tool_calls"):
|
||||
for tc in m["tool_calls"]:
|
||||
fn = tc.get("function", {})
|
||||
items.append({"type": "function_call", "call_id": tc.get("id"),
|
||||
"name": fn.get("name"), "arguments": fn.get("arguments", "")})
|
||||
if m.get("content"):
|
||||
items.append({"role": "assistant",
|
||||
"content": _chat_content_to_responses_parts(m["content"], assistant=True)})
|
||||
continue
|
||||
items.append({"role": role,
|
||||
"content": _chat_content_to_responses_parts(m.get("content"),
|
||||
assistant=(role == "assistant"))})
|
||||
instructions = "\n\n".join(p for p in instructions_parts if p) or None
|
||||
return instructions, items
|
||||
|
||||
|
||||
def responses_object_to_sse(resp):
|
||||
"""Render a *finished* Responses object as a valid SSE event stream.
|
||||
|
||||
Used to serve cache/store hits to streaming clients without a backend call.
|
||||
"""
|
||||
seq = [-1]
|
||||
|
||||
def ev(etype, payload):
|
||||
seq[0] += 1
|
||||
body = {"type": etype, "sequence_number": seq[0], **payload}
|
||||
return f"event: {etype}\ndata: {orjson.dumps(body).decode('utf-8')}\n\n".encode("utf-8")
|
||||
|
||||
parts_out = []
|
||||
in_progress = {**resp, "status": "in_progress", "output": [], "output_text": ""}
|
||||
parts_out.append(ev("response.created", {"response": in_progress}))
|
||||
parts_out.append(ev("response.in_progress", {"response": in_progress}))
|
||||
for oi, item in enumerate(resp.get("output", [])):
|
||||
parts_out.append(ev("response.output_item.added",
|
||||
{"output_index": oi, "item": {**item, "status": "in_progress"}}))
|
||||
if item.get("type") == "message":
|
||||
for ci, part in enumerate(item.get("content", [])):
|
||||
if part.get("type") == "output_text":
|
||||
iid = item.get("id")
|
||||
parts_out.append(ev("response.content_part.added", {
|
||||
"item_id": iid, "output_index": oi, "content_index": ci,
|
||||
"part": {"type": "output_text", "text": "", "annotations": []}}))
|
||||
parts_out.append(ev("response.output_text.delta", {
|
||||
"item_id": iid, "output_index": oi, "content_index": ci,
|
||||
"delta": part.get("text", "")}))
|
||||
parts_out.append(ev("response.output_text.done", {
|
||||
"item_id": iid, "output_index": oi, "content_index": ci,
|
||||
"text": part.get("text", "")}))
|
||||
parts_out.append(ev("response.content_part.done", {
|
||||
"item_id": iid, "output_index": oi, "content_index": ci, "part": part}))
|
||||
parts_out.append(ev("response.output_item.done", {"output_index": oi, "item": item}))
|
||||
parts_out.append(ev("response.completed", {"response": resp}))
|
||||
return b"".join(parts_out)
|
||||
|
||||
|
||||
def tools_responses_to_chat(tools):
|
||||
"""Map Responses tool definitions (flattened) → Chat Completions (nested)."""
|
||||
if not tools:
|
||||
return None
|
||||
out = []
|
||||
for t in tools:
|
||||
if isinstance(t, dict) and t.get("type") == "function" and "function" not in t:
|
||||
fn = {k: t[k] for k in ("name", "description", "parameters", "strict") if k in t}
|
||||
out.append({"type": "function", "function": fn})
|
||||
else:
|
||||
out.append(t)
|
||||
return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response direction: Chat Completions → Responses
|
||||
# ---------------------------------------------------------------------------
|
||||
def _new_id(prefix):
|
||||
return f"{prefix}_{secrets.token_hex(16)}"
|
||||
|
||||
|
||||
def chat_message_to_output_items(message):
|
||||
"""Convert an assistant chat message (dict) into Responses output items."""
|
||||
items = []
|
||||
content = message.get("content")
|
||||
if content:
|
||||
items.append({
|
||||
"type": "message",
|
||||
"id": _new_id("msg"),
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": content, "annotations": []}],
|
||||
})
|
||||
for tc in message.get("tool_calls") or []:
|
||||
fn = tc.get("function", {})
|
||||
items.append({
|
||||
"type": "function_call",
|
||||
"id": _new_id("fc"),
|
||||
"call_id": tc.get("id"),
|
||||
"name": fn.get("name"),
|
||||
"arguments": fn.get("arguments", ""),
|
||||
"status": "completed",
|
||||
})
|
||||
return items
|
||||
|
||||
|
||||
def usage_chat_to_responses(usage):
|
||||
"""Map chat usage ``{prompt_tokens, completion_tokens}`` → Responses usage."""
|
||||
if not usage:
|
||||
return None
|
||||
prompt = usage.get("prompt_tokens") or 0
|
||||
completion = usage.get("completion_tokens") or 0
|
||||
return {
|
||||
"input_tokens": prompt,
|
||||
"output_tokens": completion,
|
||||
"total_tokens": usage.get("total_tokens") or (prompt + completion),
|
||||
}
|
||||
|
||||
|
||||
def output_items_to_text(output_items):
|
||||
"""Concatenate the ``output_text`` parts of all message items."""
|
||||
chunks = []
|
||||
for item in output_items or []:
|
||||
if item.get("type") != "message":
|
||||
continue
|
||||
for part in item.get("content") or []:
|
||||
if part.get("type") == "output_text":
|
||||
chunks.append(part.get("text", ""))
|
||||
return "".join(chunks)
|
||||
|
||||
|
||||
def build_response_object(
|
||||
*,
|
||||
response_id,
|
||||
model,
|
||||
output_items=None,
|
||||
usage=None,
|
||||
status="completed",
|
||||
created_at=None,
|
||||
previous_response_id=None,
|
||||
instructions=None,
|
||||
error=None,
|
||||
metadata=None,
|
||||
):
|
||||
"""Assemble a full ``object:"response"`` body for a non-streaming reply."""
|
||||
output_items = output_items or []
|
||||
return {
|
||||
"id": response_id,
|
||||
"object": "response",
|
||||
"created_at": created_at or int(time.time()),
|
||||
"status": status,
|
||||
"model": model,
|
||||
"output": output_items,
|
||||
"output_text": output_items_to_text(output_items),
|
||||
"instructions": instructions,
|
||||
"previous_response_id": previous_response_id,
|
||||
"usage": usage_chat_to_responses(usage) if usage and "input_tokens" not in usage else usage,
|
||||
"error": error,
|
||||
"metadata": metadata or {},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streaming direction: Chat Completions deltas → Responses typed SSE events
|
||||
# ---------------------------------------------------------------------------
|
||||
class ChatToResponsesStream:
|
||||
"""Translate a Chat Completions streaming generator into Responses events.
|
||||
|
||||
Usage::
|
||||
|
||||
translator = ChatToResponsesStream(response_id, model, created_at)
|
||||
async for sse_bytes in translator.events(chat_async_gen):
|
||||
yield sse_bytes
|
||||
# translator.output_items / translator.usage now populated for storage
|
||||
|
||||
Emits the ordered event family
|
||||
``response.created`` → ``response.in_progress`` →
|
||||
(``response.output_item.added`` → ``response.content_part.added`` →
|
||||
``response.output_text.delta``* → ``response.output_text.done`` →
|
||||
``response.content_part.done`` → ``response.output_item.done``) and/or
|
||||
function-call item events → ``response.completed`` (carrying usage).
|
||||
"""
|
||||
|
||||
def __init__(self, response_id, model, created_at=None,
|
||||
previous_response_id=None, instructions=None, metadata=None):
|
||||
self.response_id = response_id
|
||||
self.model = model
|
||||
self.created_at = created_at or int(time.time())
|
||||
self.previous_response_id = previous_response_id
|
||||
self.instructions = instructions
|
||||
self.metadata = metadata or {}
|
||||
self.seq = -1
|
||||
self.output_items = []
|
||||
self.usage = None
|
||||
|
||||
def _snapshot(self, status, output=None):
|
||||
return build_response_object(
|
||||
response_id=self.response_id,
|
||||
model=self.model,
|
||||
output_items=output if output is not None else [],
|
||||
usage=self.usage,
|
||||
status=status,
|
||||
created_at=self.created_at,
|
||||
previous_response_id=self.previous_response_id,
|
||||
instructions=self.instructions,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
|
||||
def _event(self, etype, payload):
|
||||
self.seq += 1
|
||||
body = {"type": etype, "sequence_number": self.seq, **payload}
|
||||
return f"event: {etype}\ndata: {orjson.dumps(body).decode('utf-8')}\n\n".encode("utf-8")
|
||||
|
||||
async def events(self, async_gen):
|
||||
yield self._event("response.created", {"response": self._snapshot("in_progress")})
|
||||
yield self._event("response.in_progress", {"response": self._snapshot("in_progress")})
|
||||
|
||||
next_oi = 0
|
||||
# text message state
|
||||
msg_item_id = None
|
||||
msg_oi = None
|
||||
text_parts = []
|
||||
# function-call state, keyed by chat tool_call index
|
||||
tc_state = {} # idx -> {oi, item_id, call_id, name, args}
|
||||
|
||||
async for chunk in async_gen:
|
||||
usage = getattr(chunk, "usage", None)
|
||||
if usage is not None:
|
||||
self.usage = {
|
||||
"prompt_tokens": usage.prompt_tokens or 0,
|
||||
"completion_tokens": usage.completion_tokens or 0,
|
||||
}
|
||||
choices = getattr(chunk, "choices", None)
|
||||
if not choices:
|
||||
continue
|
||||
delta = choices[0].delta
|
||||
|
||||
content_piece = getattr(delta, "content", None)
|
||||
if content_piece:
|
||||
if msg_item_id is None:
|
||||
msg_item_id = _new_id("msg")
|
||||
msg_oi = next_oi
|
||||
next_oi += 1
|
||||
item = {
|
||||
"id": msg_item_id, "type": "message", "status": "in_progress",
|
||||
"role": "assistant", "content": [],
|
||||
}
|
||||
yield self._event("response.output_item.added",
|
||||
{"output_index": msg_oi, "item": item})
|
||||
yield self._event("response.content_part.added", {
|
||||
"item_id": msg_item_id, "output_index": msg_oi, "content_index": 0,
|
||||
"part": {"type": "output_text", "text": "", "annotations": []},
|
||||
})
|
||||
text_parts.append(content_piece)
|
||||
yield self._event("response.output_text.delta", {
|
||||
"item_id": msg_item_id, "output_index": msg_oi, "content_index": 0,
|
||||
"delta": content_piece,
|
||||
})
|
||||
|
||||
for tc in getattr(delta, "tool_calls", None) or []:
|
||||
idx = tc.index
|
||||
fn = getattr(tc, "function", None)
|
||||
if idx not in tc_state:
|
||||
item_id = _new_id("fc")
|
||||
state = {
|
||||
"oi": next_oi, "item_id": item_id,
|
||||
"call_id": getattr(tc, "id", None) or _new_id("call"),
|
||||
"name": (fn.name if fn else None), "args": "",
|
||||
}
|
||||
next_oi += 1
|
||||
tc_state[idx] = state
|
||||
yield self._event("response.output_item.added", {
|
||||
"output_index": state["oi"],
|
||||
"item": {
|
||||
"id": item_id, "type": "function_call", "status": "in_progress",
|
||||
"call_id": state["call_id"], "name": state["name"], "arguments": "",
|
||||
},
|
||||
})
|
||||
else:
|
||||
state = tc_state[idx]
|
||||
if getattr(tc, "id", None):
|
||||
state["call_id"] = tc.id
|
||||
if fn and fn.name:
|
||||
state["name"] = fn.name
|
||||
if fn and fn.arguments:
|
||||
state["args"] += fn.arguments
|
||||
yield self._event("response.function_call_arguments.delta", {
|
||||
"item_id": state["item_id"], "output_index": state["oi"],
|
||||
"delta": fn.arguments,
|
||||
})
|
||||
|
||||
# finalize message item
|
||||
if msg_item_id is not None:
|
||||
full_text = "".join(text_parts)
|
||||
yield self._event("response.output_text.done", {
|
||||
"item_id": msg_item_id, "output_index": msg_oi, "content_index": 0,
|
||||
"text": full_text,
|
||||
})
|
||||
done_part = {"type": "output_text", "text": full_text, "annotations": []}
|
||||
yield self._event("response.content_part.done", {
|
||||
"item_id": msg_item_id, "output_index": msg_oi, "content_index": 0,
|
||||
"part": done_part,
|
||||
})
|
||||
msg_item = {
|
||||
"id": msg_item_id, "type": "message", "status": "completed",
|
||||
"role": "assistant", "content": [done_part],
|
||||
}
|
||||
yield self._event("response.output_item.done",
|
||||
{"output_index": msg_oi, "item": msg_item})
|
||||
|
||||
# finalize function-call items (in output-index order)
|
||||
tc_items = {}
|
||||
for idx, state in tc_state.items():
|
||||
yield self._event("response.function_call_arguments.done", {
|
||||
"item_id": state["item_id"], "output_index": state["oi"],
|
||||
"arguments": state["args"],
|
||||
})
|
||||
fc_item = {
|
||||
"id": state["item_id"], "type": "function_call", "status": "completed",
|
||||
"call_id": state["call_id"], "name": state["name"], "arguments": state["args"],
|
||||
}
|
||||
tc_items[state["oi"]] = fc_item
|
||||
yield self._event("response.output_item.done",
|
||||
{"output_index": state["oi"], "item": fc_item})
|
||||
|
||||
# assemble final output items ordered by output index
|
||||
ordered = []
|
||||
if msg_item_id is not None:
|
||||
ordered.append((msg_oi, msg_item))
|
||||
ordered.extend(tc_items.items())
|
||||
self.output_items = [item for _, item in sorted(ordered, key=lambda kv: kv[0])]
|
||||
|
||||
yield self._event("response.completed",
|
||||
{"response": self._snapshot("completed", self.output_items)})
|
||||
15
router.py
15
router.py
|
|
@ -231,6 +231,7 @@ from backends.health import (
|
|||
from backends.normalize import (
|
||||
is_ext_openai_endpoint,
|
||||
is_openai_compatible,
|
||||
llama_endpoints,
|
||||
get_tracking_model,
|
||||
)
|
||||
|
||||
|
|
@ -290,6 +291,8 @@ from api.management import router as management_router
|
|||
app.include_router(management_router)
|
||||
from api.openai import router as openai_router
|
||||
app.include_router(openai_router)
|
||||
from api.responses import router as responses_router
|
||||
app.include_router(responses_router)
|
||||
from api.ollama import router as ollama_router
|
||||
app.include_router(ollama_router)
|
||||
|
||||
|
|
@ -308,6 +311,7 @@ async def startup_event() -> None:
|
|||
f"Loaded configuration from {config_path}:\n"
|
||||
f" endpoints={config.endpoints},\n"
|
||||
f" llama_server_endpoints={config.llama_server_endpoints},\n"
|
||||
f" llama_swap_endpoints={config.llama_swap_endpoints},\n"
|
||||
f" max_concurrent_connections={config.max_concurrent_connections},\n"
|
||||
f" endpoint_config={config.endpoint_config},\n"
|
||||
f" priority_routing={config.priority_routing}"
|
||||
|
|
@ -322,6 +326,13 @@ async def startup_event() -> None:
|
|||
db = TokenDatabase(config.db_path)
|
||||
await db.init_db()
|
||||
|
||||
# Reconcile Responses-API background tasks lost across a restart: their
|
||||
# in-memory asyncio task is gone but the DB row may still read queued /
|
||||
# in_progress, so mark those failed to give polling clients a terminal state.
|
||||
_orphaned = await db.fail_orphaned_responses()
|
||||
if _orphaned:
|
||||
print(f"[startup] Marked {_orphaned} orphaned background response(s) as failed.")
|
||||
|
||||
# Load existing token counts from database
|
||||
async for count_entry in db.load_token_counts():
|
||||
endpoint = count_entry['endpoint']
|
||||
|
|
@ -365,7 +376,7 @@ async def startup_event() -> None:
|
|||
app_state["httpx_clients"][ep] = httpx.AsyncClient(timeout=30.0)
|
||||
|
||||
# Create per-endpoint Unix socket sessions for .sock endpoints
|
||||
for ep in config.llama_server_endpoints:
|
||||
for ep in llama_endpoints(config):
|
||||
if _is_unix_socket_endpoint(ep):
|
||||
sock_path = _get_socket_path(ep)
|
||||
sock_connector = aiohttp.UnixConnector(path=sock_path)
|
||||
|
|
@ -382,7 +393,7 @@ async def startup_event() -> None:
|
|||
# client (/api/chat, /api/generate) and the OpenAI client (/v1/* routes),
|
||||
# so warm both; OpenAI-compatible endpoints only need the OpenAI client.
|
||||
_warm_endpoints = config.endpoints + [
|
||||
ep for ep in config.llama_server_endpoints if ep not in config.endpoints
|
||||
ep for ep in llama_endpoints(config) if ep not in config.endpoints
|
||||
]
|
||||
for ep in _warm_endpoints:
|
||||
try:
|
||||
|
|
|
|||
71
routing.py
71
routing.py
|
|
@ -32,6 +32,8 @@ from backends.health import _is_fresh
|
|||
from backends.normalize import (
|
||||
is_ext_openai_endpoint,
|
||||
is_openai_compatible,
|
||||
is_llama_server,
|
||||
llama_endpoints,
|
||||
get_tracking_model,
|
||||
)
|
||||
from backends.probe import fetch
|
||||
|
|
@ -93,8 +95,8 @@ async def choose_endpoint(model: str, reserve: bool = True,
|
|||
"""
|
||||
config = get_config()
|
||||
# 1️⃣ Gather advertised‑model sets for all endpoints concurrently
|
||||
# Include both config.endpoints and config.llama_server_endpoints
|
||||
llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
|
||||
# Include config.endpoints plus any llama-server / llama-swap endpoints
|
||||
llama_eps_extra = [ep for ep in llama_endpoints(config) if ep not in config.endpoints]
|
||||
all_endpoints = config.endpoints + llama_eps_extra
|
||||
|
||||
tag_tasks = [fetch.available_models(ep) for ep in config.endpoints if not is_openai_compatible(ep)]
|
||||
|
|
@ -114,7 +116,7 @@ async def choose_endpoint(model: str, reserve: bool = True,
|
|||
model_without_latest = model.split(":latest")[0]
|
||||
candidate_endpoints = [
|
||||
ep for ep, models in zip(all_endpoints, advertised_sets)
|
||||
if model_without_latest in models and (is_ext_openai_endpoint(ep) or ep in config.llama_server_endpoints)
|
||||
if model_without_latest in models and (is_ext_openai_endpoint(ep) or is_llama_server(ep))
|
||||
]
|
||||
if not candidate_endpoints:
|
||||
# Only add :latest suffix if model doesn't already have a version suffix
|
||||
|
|
@ -202,6 +204,43 @@ async def choose_endpoint(model: str, reserve: bool = True,
|
|||
def utilization_ratio(ep: str) -> float:
|
||||
return tracking_usage(ep) / get_max_connections(ep)
|
||||
|
||||
def total_load(ep: str) -> int:
|
||||
"""Sum of in-flight requests across *all* models on the endpoint."""
|
||||
return sum(usage_counts.get(ep, {}).values())
|
||||
|
||||
# How many models each candidate currently has *resident* (from the
|
||||
# /api/ps probe). With infinite keep-alive a model stays loaded long
|
||||
# after its in-flight count drops to zero, so this is the signal that
|
||||
# spreads *distinct* models across backends.
|
||||
ep_loaded_counts = {
|
||||
ep: len(models) for ep, models in zip(candidate_endpoints, loaded_sets)
|
||||
}
|
||||
|
||||
def loaded_count(ep: str) -> int:
|
||||
return ep_loaded_counts.get(ep, 0)
|
||||
|
||||
def pick_least_loaded(eps: list[str]) -> str:
|
||||
"""Pick the least-committed endpoint, breaking ties at random.
|
||||
|
||||
Ordering key is ``(total_load, loaded_count)``:
|
||||
|
||||
* ``total_load`` (in-flight requests across *all* models) keeps a
|
||||
request off a backend already busy with a *different* model —
|
||||
otherwise the per-model count reads zero everywhere and the
|
||||
ranking is discarded (cold model B landing on the box serving A).
|
||||
* ``loaded_count`` (number of *resident* models) then spreads
|
||||
distinct models across backends. Two different cold models (27b,
|
||||
35b) requested back-to-back must not pile onto the same box: once
|
||||
27b is resident there, that box has loaded_count 1 while the idle
|
||||
backends have 0, so the next cold model prefers an empty backend
|
||||
even though every backend reports zero in-flight load.
|
||||
|
||||
``random.choice`` only breaks genuine ties on both keys, so a single
|
||||
idle cluster still distributes the very first cold model evenly."""
|
||||
best = min((total_load(ep), loaded_count(ep)) for ep in eps)
|
||||
tied = [ep for ep in eps if (total_load(ep), loaded_count(ep)) == best]
|
||||
return random.choice(tied)
|
||||
|
||||
# Priority map: position in all_endpoints list (lower = higher priority)
|
||||
ep_priority = {ep: i for i, ep in enumerate(all_endpoints)}
|
||||
|
||||
|
|
@ -235,15 +274,11 @@ async def choose_endpoint(model: str, reserve: bool = True,
|
|||
loaded_and_free.sort(key=utilization_ratio)
|
||||
selected = loaded_and_free[0]
|
||||
else:
|
||||
# Sort ascending for load balancing — all endpoints here already have the
|
||||
# model loaded, so there is no model-switching cost to optimise for.
|
||||
loaded_and_free.sort(key=tracking_usage)
|
||||
# When all candidates are equally idle, randomise to avoid always picking
|
||||
# the first entry in a stable sort.
|
||||
if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
|
||||
selected = random.choice(loaded_and_free)
|
||||
else:
|
||||
selected = loaded_and_free[0]
|
||||
# All endpoints here already have the model loaded, so there
|
||||
# is no model-switching cost to optimise for. Pick the least
|
||||
# *total*-loaded one (tie broken at random) so we steer away
|
||||
# from a backend busy serving other models.
|
||||
selected = pick_least_loaded(loaded_and_free)
|
||||
else:
|
||||
# 4️⃣ Endpoints among the candidates that simply have a free slot
|
||||
endpoints_with_free_slot = [
|
||||
|
|
@ -257,14 +292,10 @@ async def choose_endpoint(model: str, reserve: bool = True,
|
|||
endpoints_with_free_slot.sort(key=utilization_ratio)
|
||||
selected = endpoints_with_free_slot[0]
|
||||
else:
|
||||
# Sort by total endpoint load (ascending) to prefer idle endpoints.
|
||||
endpoints_with_free_slot.sort(
|
||||
key=lambda ep: sum(usage_counts.get(ep, {}).values())
|
||||
)
|
||||
if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
|
||||
selected = random.choice(endpoints_with_free_slot)
|
||||
else:
|
||||
selected = endpoints_with_free_slot[0]
|
||||
# Prefer the endpoint with the lowest *total* load so the
|
||||
# cold-start cost lands on genuinely idle hardware rather
|
||||
# than a backend already busy with a different model.
|
||||
selected = pick_least_loaded(endpoints_with_free_slot)
|
||||
else:
|
||||
# 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue)
|
||||
if config.priority_routing:
|
||||
|
|
|
|||
|
|
@ -4,10 +4,14 @@ endpoints:
|
|||
llama_server_endpoints:
|
||||
- http://192.168.0.51:12434/v1
|
||||
|
||||
llama_swap_endpoints:
|
||||
- http://192.168.0.51:12435/v1
|
||||
|
||||
max_concurrent_connections: 2
|
||||
|
||||
api_keys:
|
||||
"http://192.168.0.51:12434": "ollama"
|
||||
"http://192.168.0.51:12434/v1": "llama"
|
||||
"http://192.168.0.51:12435/v1": "llama-swap"
|
||||
|
||||
cache_enabled: false
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ def mock_config():
|
|||
cfg = MagicMock()
|
||||
cfg.endpoints = [TEST_OLLAMA]
|
||||
cfg.llama_server_endpoints = [TEST_LLAMA]
|
||||
cfg.llama_swap_endpoints = []
|
||||
cfg.api_keys = {TEST_OLLAMA: "ollama", TEST_LLAMA: "llama"}
|
||||
cfg.max_concurrent_connections = 2
|
||||
cfg.router_api_key = None
|
||||
|
|
@ -70,6 +71,7 @@ def mock_config_no_llama():
|
|||
cfg = MagicMock()
|
||||
cfg.endpoints = [TEST_OLLAMA]
|
||||
cfg.llama_server_endpoints = []
|
||||
cfg.llama_swap_endpoints = []
|
||||
cfg.api_keys = {TEST_OLLAMA: "ollama"}
|
||||
cfg.max_concurrent_connections = 2
|
||||
cfg.router_api_key = None
|
||||
|
|
@ -83,6 +85,7 @@ def mock_config_with_key():
|
|||
cfg = MagicMock()
|
||||
cfg.endpoints = [TEST_OLLAMA]
|
||||
cfg.llama_server_endpoints = []
|
||||
cfg.llama_swap_endpoints = []
|
||||
cfg.api_keys = {}
|
||||
cfg.max_concurrent_connections = 2
|
||||
cfg.router_api_key = "test-secret-key"
|
||||
|
|
|
|||
|
|
@ -12,10 +12,11 @@ EP3 = "http://ep3:11434"
|
|||
LLAMA_EP = "http://llama:8080/v1"
|
||||
|
||||
|
||||
def _make_cfg(endpoints, llama_eps=None, max_conn=2, endpoint_config=None, priority_routing=False):
|
||||
def _make_cfg(endpoints, llama_eps=None, swap_eps=None, max_conn=2, endpoint_config=None, priority_routing=False):
|
||||
cfg = MagicMock()
|
||||
cfg.endpoints = endpoints
|
||||
cfg.llama_server_endpoints = llama_eps or []
|
||||
cfg.llama_swap_endpoints = swap_eps or []
|
||||
cfg.api_keys = {}
|
||||
cfg.max_concurrent_connections = max_conn
|
||||
cfg.endpoint_config = endpoint_config or {}
|
||||
|
|
@ -46,6 +47,27 @@ class TestChooseEndpointBasic:
|
|||
assert ep == EP1
|
||||
assert tracking == "llama3.2:latest"
|
||||
|
||||
async def test_llama_swap_endpoint_is_a_candidate(self):
|
||||
swap_ep = "http://swap:8080/v1"
|
||||
cfg = _make_cfg([EP1], swap_eps=[swap_ep])
|
||||
|
||||
async def available(ep, *_):
|
||||
# Only the llama-swap backend advertises this model
|
||||
return {"org/model:Q4_K_M"} if ep == swap_ep else set()
|
||||
|
||||
async def loaded(ep):
|
||||
return {"org/model:Q4_K_M"} if ep == swap_ep else set()
|
||||
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", side_effect=available),
|
||||
patch.object(router.fetch, "loaded_models", side_effect=loaded),
|
||||
):
|
||||
ep, tracking = await router.choose_endpoint("org/model:Q4_K_M")
|
||||
assert ep == swap_ep
|
||||
# llama-swap models are tracked under their normalized name
|
||||
assert tracking == "model"
|
||||
|
||||
async def test_raises_when_no_endpoint_has_model(self):
|
||||
cfg = _make_cfg([EP1, EP2])
|
||||
with (
|
||||
|
|
@ -85,6 +107,64 @@ class TestChooseEndpointBasic:
|
|||
ep, _ = await router.choose_endpoint("llama3.2:latest")
|
||||
assert ep in (EP1, EP2)
|
||||
|
||||
async def test_cold_model_avoids_backend_busy_with_other_model(self):
|
||||
# Regression: heterogeneous cluster. A cold model B (loaded nowhere)
|
||||
# must not be routed to a backend already serving a *different* model
|
||||
# while other backends sit idle. The step-4 idle check used to look at
|
||||
# per-model usage (zero everywhere for B) and discard the total-load
|
||||
# ranking, so B could land on the busy backend at random.
|
||||
cfg = _make_cfg([EP1, EP2, EP3], max_conn=4)
|
||||
|
||||
async def available(ep, *_):
|
||||
return {"model-a:latest", "model-b:latest"}
|
||||
|
||||
# EP3 is busy with model A; EP1 and EP2 are completely idle. Model B
|
||||
# is loaded nowhere.
|
||||
router.usage_counts[EP3]["model-a:latest"] = 1
|
||||
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", side_effect=available),
|
||||
patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
|
||||
):
|
||||
# Run repeatedly: the busy backend must be excluded every time,
|
||||
# the idle two share the load at random.
|
||||
for _ in range(50):
|
||||
ep, _ = await router.choose_endpoint("model-b:latest", reserve=False)
|
||||
assert ep in (EP1, EP2)
|
||||
assert ep != EP3
|
||||
|
||||
async def test_two_cold_models_spread_across_backends(self):
|
||||
# Regression: 3 backends all advertise all models. Two *different*
|
||||
# cold models requested back-to-back must land on *different*
|
||||
# backends. Once model-a is resident on the chosen backend (infinite
|
||||
# keep-alive), its in-flight count drops back to 0 — so only the
|
||||
# resident-model count distinguishes the backends. Without it, the
|
||||
# second cold model would randomly re-collide on the busy backend.
|
||||
cfg = _make_cfg([EP1, EP2, EP3], max_conn=4)
|
||||
|
||||
async def available(ep, *_):
|
||||
return {"model-a:latest", "model-b:latest"}
|
||||
|
||||
# model-a finished loading on EP1 and stays resident; its request has
|
||||
# completed so EP1 has zero in-flight load, same as EP2/EP3.
|
||||
loaded = {EP1: {"model-a:latest"}, EP2: set(), EP3: set()}
|
||||
|
||||
async def loaded_models(ep):
|
||||
return loaded[ep]
|
||||
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", side_effect=available),
|
||||
patch.object(router.fetch, "loaded_models", side_effect=loaded_models),
|
||||
):
|
||||
# A cold model-b must avoid EP1 (which already holds model-a) and
|
||||
# go to one of the empty backends, every time.
|
||||
for _ in range(50):
|
||||
ep, _ = await router.choose_endpoint("model-b:latest", reserve=False)
|
||||
assert ep in (EP2, EP3)
|
||||
assert ep != EP1
|
||||
|
||||
async def test_saturated_picks_least_busy(self):
|
||||
cfg = _make_cfg([EP1, EP2])
|
||||
cfg.max_concurrent_connections = 1
|
||||
|
|
|
|||
|
|
@ -20,10 +20,11 @@ MOCK_OLLAMA_EP = "http://mock-ollama:11434"
|
|||
MOCK_LLAMA_EP = "http://mock-llama:8080/v1"
|
||||
|
||||
|
||||
def _make_cfg(ollama_eps=None, llama_eps=None, api_keys=None):
|
||||
def _make_cfg(ollama_eps=None, llama_eps=None, swap_eps=None, api_keys=None):
|
||||
cfg = MagicMock()
|
||||
cfg.endpoints = ollama_eps or [MOCK_OLLAMA_EP]
|
||||
cfg.llama_server_endpoints = llama_eps or [MOCK_LLAMA_EP]
|
||||
cfg.llama_swap_endpoints = swap_eps or []
|
||||
cfg.api_keys = api_keys or {}
|
||||
cfg.max_concurrent_connections = 2
|
||||
cfg.router_api_key = None
|
||||
|
|
@ -228,6 +229,30 @@ class TestFetchLoadedModels:
|
|||
models = await router.fetch.loaded_models(MOCK_LLAMA_EP)
|
||||
assert "always-on-model" in models
|
||||
|
||||
async def test_llama_swap_reads_running_state_ready(self):
|
||||
# llama-swap omits the /v1/models status field, so loaded workers come
|
||||
# from /running (a root route — the /v1 suffix must be stripped).
|
||||
swap_ep = "http://mock-swap:8080/v1"
|
||||
cfg = _make_cfg(llama_eps=[], swap_eps=[swap_ep])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get(
|
||||
"http://mock-swap:8080/running",
|
||||
payload={"running": [
|
||||
{"model": "org/ready-model:Q4_K_M", "state": "ready"},
|
||||
{"model": "org/starting-model:Q8_0", "state": "starting"},
|
||||
]},
|
||||
)
|
||||
models = await router.fetch.loaded_models(swap_ep)
|
||||
assert models == {"org/ready-model:Q4_K_M"}
|
||||
|
||||
async def test_llama_swap_records_error_on_failure(self):
|
||||
swap_ep = "http://mock-swap:8080/v1"
|
||||
cfg = _make_cfg(llama_eps=[], swap_eps=[swap_ep])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get("http://mock-swap:8080/running", status=502, payload={})
|
||||
await router.fetch.loaded_models(swap_ep)
|
||||
assert swap_ep in router._loaded_error_cache
|
||||
|
||||
async def test_returns_empty_on_error(self):
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
|
|
|
|||
131
test/test_llama_swap.py
Normal file
131
test/test_llama_swap.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
"""Tests for llama-swap specific behavior: unload dispatch + /upstream resolution."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import router
|
||||
import backends.control as control
|
||||
import api.openai as openai_api
|
||||
import api.ollama as ollama_api
|
||||
|
||||
SWAP_EP = "http://swap:8080/v1"
|
||||
SERVER_EP = "http://server:8080/v1"
|
||||
|
||||
|
||||
def _cfg(*, server=None, swap=None, api_keys=None):
|
||||
cfg = MagicMock()
|
||||
cfg.endpoints = []
|
||||
cfg.llama_server_endpoints = server or []
|
||||
cfg.llama_swap_endpoints = swap or []
|
||||
cfg.api_keys = api_keys or {}
|
||||
return cfg
|
||||
|
||||
|
||||
class _RecordingSession:
|
||||
"""Captures the most recent ``post`` call and returns a 200 response."""
|
||||
|
||||
def __init__(self, status=200):
|
||||
self.calls = []
|
||||
self._status = status
|
||||
|
||||
def post(self, url, **kwargs):
|
||||
self.calls.append((url, kwargs))
|
||||
resp = MagicMock()
|
||||
resp.status = self._status
|
||||
|
||||
class _Ctx:
|
||||
async def __aenter__(self_):
|
||||
return resp
|
||||
|
||||
async def __aexit__(self_, *exc):
|
||||
return False
|
||||
|
||||
return _Ctx()
|
||||
|
||||
|
||||
class TestUnloadDispatch:
|
||||
async def test_llama_swap_uses_path_param(self):
|
||||
sess = _RecordingSession()
|
||||
cfg = _cfg(swap=[SWAP_EP])
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(control, "get_probe_session", lambda ep: sess),
|
||||
):
|
||||
ok = await control.unload_model(SWAP_EP, "org/model:Q4_K_M")
|
||||
assert ok is True
|
||||
url, kwargs = sess.calls[0]
|
||||
# /v1 stripped, model id is a path param, no JSON body
|
||||
assert url == "http://swap:8080/api/models/unload/org/model:Q4_K_M"
|
||||
assert kwargs.get("json") is None
|
||||
|
||||
async def test_llama_server_uses_body(self):
|
||||
sess = _RecordingSession()
|
||||
cfg = _cfg(server=[SERVER_EP])
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(control, "get_probe_session", lambda ep: sess),
|
||||
):
|
||||
ok = await control.unload_model(SERVER_EP, "org/model:Q4_K_M")
|
||||
assert ok is True
|
||||
url, kwargs = sess.calls[0]
|
||||
assert url == "http://server:8080/models/unload"
|
||||
assert kwargs.get("json") == {"model": "org/model:Q4_K_M"}
|
||||
|
||||
async def test_unload_failure_returns_false(self):
|
||||
sess = _RecordingSession(status=500)
|
||||
cfg = _cfg(swap=[SWAP_EP])
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(control, "get_probe_session", lambda ep: sess),
|
||||
):
|
||||
ok = await control.unload_model(SWAP_EP, "m")
|
||||
assert ok is False
|
||||
|
||||
|
||||
class TestUpstreamResolution:
|
||||
async def test_resolves_endpoint_that_advertises_model(self):
|
||||
cfg = _cfg(swap=[SWAP_EP])
|
||||
with (
|
||||
patch.object(openai_api, "get_config", lambda: cfg),
|
||||
patch.object(openai_api.fetch, "available_models",
|
||||
AsyncMock(return_value={"org/model:Q4_K_M"})),
|
||||
):
|
||||
ep = await openai_api._resolve_llama_swap_endpoint("org/model:Q4_K_M")
|
||||
assert ep == SWAP_EP
|
||||
|
||||
async def test_returns_none_when_unserved(self):
|
||||
cfg = _cfg(swap=[SWAP_EP])
|
||||
with (
|
||||
patch.object(openai_api, "get_config", lambda: cfg),
|
||||
patch.object(openai_api.fetch, "available_models",
|
||||
AsyncMock(return_value=set())),
|
||||
):
|
||||
ep = await openai_api._resolve_llama_swap_endpoint("missing")
|
||||
assert ep is None
|
||||
|
||||
async def test_returns_none_without_swap_endpoints(self):
|
||||
cfg = _cfg(swap=[])
|
||||
with patch.object(openai_api, "get_config", lambda: cfg):
|
||||
ep = await openai_api._resolve_llama_swap_endpoint("any")
|
||||
assert ep is None
|
||||
|
||||
|
||||
class TestCtxSizeFromCmd:
|
||||
"""ctx-size parsing from a /running worker's launch `cmd` string."""
|
||||
|
||||
def test_parses_long_flag(self):
|
||||
cmd = ("llama-server --port 5818\n -hf unsloth/gpt-oss-20b-GGUF:F16\n"
|
||||
" --ctx-size 131072\n --temp 1.0\n")
|
||||
assert ollama_api._ctx_size_from_cmd(cmd) == 131072
|
||||
|
||||
def test_parses_short_flag(self):
|
||||
assert ollama_api._ctx_size_from_cmd("llama-server -c 8192 --port 1") == 8192
|
||||
|
||||
def test_parses_equals_form(self):
|
||||
assert ollama_api._ctx_size_from_cmd("llama-server --ctx-size=4096") == 4096
|
||||
|
||||
def test_returns_none_when_absent(self):
|
||||
assert ollama_api._ctx_size_from_cmd("llama-server --port 5818") is None
|
||||
|
||||
def test_returns_none_for_empty(self):
|
||||
assert ollama_api._ctx_size_from_cmd("") is None
|
||||
460
test/test_responses.py
Normal file
460
test/test_responses.py
Normal file
|
|
@ -0,0 +1,460 @@
|
|||
"""Tests for the OpenAI Responses API support (api/responses.py + requests/responses.py).
|
||||
|
||||
Covers the pure translation layer, the translated (Ollama-style) and native
|
||||
(external-OpenAI) backend paths, conversation storage / chaining, background mode,
|
||||
and the retrieve / delete / cancel routes.
|
||||
"""
|
||||
import asyncio
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from types import SimpleNamespace as NS
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
import router
|
||||
from api import responses as api_responses
|
||||
from requests import responses as rt
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Pure translation unit tests (no app / no I/O)
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestTranslationInputToMessages:
|
||||
def test_string_input(self):
|
||||
msgs = rt.responses_input_to_messages("hello")
|
||||
assert msgs == [{"role": "user", "content": "hello"}]
|
||||
|
||||
def test_instructions_become_system(self):
|
||||
msgs = rt.responses_input_to_messages("hi", instructions="be brief")
|
||||
assert msgs[0] == {"role": "system", "content": "be brief"}
|
||||
assert msgs[1] == {"role": "user", "content": "hi"}
|
||||
|
||||
def test_item_list_text_and_image(self):
|
||||
items = [{
|
||||
"type": "message", "role": "user",
|
||||
"content": [
|
||||
{"type": "input_text", "text": "describe"},
|
||||
{"type": "input_image", "image_url": "http://x/y.png"},
|
||||
],
|
||||
}]
|
||||
msgs = rt.responses_input_to_messages(items)
|
||||
assert msgs[0]["role"] == "user"
|
||||
assert msgs[0]["content"] == [
|
||||
{"type": "text", "text": "describe"},
|
||||
{"type": "image_url", "image_url": {"url": "http://x/y.png"}},
|
||||
]
|
||||
|
||||
def test_single_text_part_collapses_to_string(self):
|
||||
items = [{"type": "message", "role": "user",
|
||||
"content": [{"type": "input_text", "text": "yo"}]}]
|
||||
assert rt.responses_input_to_messages(items)[0]["content"] == "yo"
|
||||
|
||||
def test_function_call_roundtrip(self):
|
||||
items = [
|
||||
{"type": "function_call", "call_id": "c1", "name": "get", "arguments": "{\"x\":1}"},
|
||||
{"type": "function_call_output", "call_id": "c1", "output": "42"},
|
||||
]
|
||||
msgs = rt.responses_input_to_messages(items)
|
||||
assert msgs[0]["role"] == "assistant"
|
||||
assert msgs[0]["tool_calls"][0]["id"] == "c1"
|
||||
assert msgs[0]["tool_calls"][0]["function"]["name"] == "get"
|
||||
assert msgs[1] == {"role": "tool", "tool_call_id": "c1", "content": "42"}
|
||||
|
||||
|
||||
class TestTranslationResponseDirection:
|
||||
def test_chat_message_to_output_items_text(self):
|
||||
items = rt.chat_message_to_output_items({"role": "assistant", "content": "hi there"})
|
||||
assert len(items) == 1
|
||||
assert items[0]["type"] == "message"
|
||||
assert items[0]["content"][0] == {"type": "output_text", "text": "hi there", "annotations": []}
|
||||
|
||||
def test_chat_message_to_output_items_tool_call(self):
|
||||
items = rt.chat_message_to_output_items({
|
||||
"role": "assistant", "content": None,
|
||||
"tool_calls": [{"id": "c9", "function": {"name": "f", "arguments": "{}"}}],
|
||||
})
|
||||
assert items[0]["type"] == "function_call"
|
||||
assert items[0]["call_id"] == "c9"
|
||||
assert items[0]["name"] == "f"
|
||||
|
||||
def test_usage_mapping(self):
|
||||
u = rt.usage_chat_to_responses({"prompt_tokens": 7, "completion_tokens": 3})
|
||||
assert u == {"input_tokens": 7, "output_tokens": 3, "total_tokens": 10}
|
||||
|
||||
def test_build_response_object_output_text(self):
|
||||
items = rt.chat_message_to_output_items({"role": "assistant", "content": "abc"})
|
||||
obj = rt.build_response_object(response_id="resp_1", model="m", output_items=items)
|
||||
assert obj["object"] == "response"
|
||||
assert obj["output_text"] == "abc"
|
||||
assert obj["status"] == "completed"
|
||||
|
||||
def test_tools_responses_to_chat(self):
|
||||
tools = [{"type": "function", "name": "f", "description": "d", "parameters": {"type": "object"}}]
|
||||
chat_tools = rt.tools_responses_to_chat(tools)
|
||||
assert chat_tools == [{"type": "function",
|
||||
"function": {"name": "f", "description": "d",
|
||||
"parameters": {"type": "object"}}}]
|
||||
|
||||
def test_messages_to_responses_input(self):
|
||||
instr, items = rt.messages_to_responses_input([
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": "yo"},
|
||||
])
|
||||
assert instr == "sys"
|
||||
assert items[0] == {"role": "user", "content": [{"type": "input_text", "text": "hi"}]}
|
||||
assert items[1] == {"role": "assistant", "content": [{"type": "output_text", "text": "yo"}]}
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Fakes for backend generators
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _fake_completion(content="hello world", usage=(3, 5)):
|
||||
msg = MagicMock()
|
||||
msg.model_dump.return_value = {"role": "assistant", "content": content}
|
||||
usage_obj = MagicMock()
|
||||
usage_obj.model_dump.return_value = {
|
||||
"prompt_tokens": usage[0], "completion_tokens": usage[1], "total_tokens": sum(usage)}
|
||||
return NS(choices=[NS(message=msg)], usage=usage_obj)
|
||||
|
||||
|
||||
def _chunk(content=None, tool_calls=None):
|
||||
return NS(choices=[NS(delta=NS(content=content, tool_calls=tool_calls),
|
||||
finish_reason=None)], usage=None)
|
||||
|
||||
|
||||
def _usage_chunk(p, c):
|
||||
return NS(choices=[], usage=NS(prompt_tokens=p, completion_tokens=c))
|
||||
|
||||
|
||||
def _text_chunks():
|
||||
async def _gen():
|
||||
yield _chunk(content="Hel")
|
||||
yield _chunk(content="lo")
|
||||
yield _usage_chunk(3, 5)
|
||||
return _gen()
|
||||
|
||||
|
||||
def _toolcall_chunks():
|
||||
tc0 = NS(index=0, id="call_1", function=NS(name="lookup", arguments='{"q":'))
|
||||
tc1 = NS(index=0, id=None, function=NS(name=None, arguments='"hi"}'))
|
||||
|
||||
async def _gen():
|
||||
yield _chunk(tool_calls=[tc0])
|
||||
yield _chunk(tool_calls=[tc1])
|
||||
yield _usage_chunk(4, 2)
|
||||
return _gen()
|
||||
|
||||
|
||||
class _FakeEvent:
|
||||
def __init__(self, data):
|
||||
self._data = data
|
||||
|
||||
def model_dump(self):
|
||||
return self._data
|
||||
|
||||
|
||||
def _native_event_stream():
|
||||
async def _gen():
|
||||
yield _FakeEvent({"type": "response.created",
|
||||
"response": {"id": "resp_openai", "status": "in_progress", "output": []}})
|
||||
yield _FakeEvent({"type": "response.output_text.delta",
|
||||
"item_id": "msg_1", "output_index": 0, "delta": "hi"})
|
||||
yield _FakeEvent({"type": "response.completed", "response": {
|
||||
"id": "resp_openai", "status": "completed",
|
||||
"output": [{"type": "message", "role": "assistant",
|
||||
"content": [{"type": "output_text", "text": "hi"}]}],
|
||||
"usage": {"input_tokens": 2, "output_tokens": 1, "total_tokens": 3}}})
|
||||
return _gen()
|
||||
|
||||
|
||||
def _sse_events(text):
|
||||
"""Split an SSE body into a list of (event_type, data_dict)."""
|
||||
out = []
|
||||
for frame in text.strip().split("\n\n"):
|
||||
if not frame.strip():
|
||||
continue
|
||||
etype = data = None
|
||||
for line in frame.splitlines():
|
||||
if line.startswith("event: "):
|
||||
etype = line[len("event: "):]
|
||||
elif line.startswith("data: "):
|
||||
data = orjson.loads(line[len("data: "):])
|
||||
out.append((etype, data))
|
||||
return out
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _enter(*cms):
|
||||
"""Enter a variable number of context managers (works with *unpacked tuples)."""
|
||||
with ExitStack() as stack:
|
||||
for cm in cms:
|
||||
stack.enter_context(cm)
|
||||
yield
|
||||
|
||||
|
||||
def _patch_backend(native=False, endpoint="http://ollama:11434"):
|
||||
"""Context managers patching endpoint selection + client construction."""
|
||||
return (
|
||||
patch.object(api_responses, "choose_endpoint",
|
||||
AsyncMock(return_value=(endpoint, "test-model:latest"))),
|
||||
patch.object(api_responses, "decrement_usage", AsyncMock()),
|
||||
patch.object(api_responses, "is_ext_openai_endpoint", return_value=native),
|
||||
patch.object(api_responses, "_make_openai_client", return_value=MagicMock()),
|
||||
patch.object(api_responses, "get_llm_cache", return_value=None),
|
||||
)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Translated path (Ollama-style backend)
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestTranslatedPath:
|
||||
async def test_nonstream(self, client):
|
||||
with _enter(*_patch_backend(native=False),
|
||||
patch.object(api_responses, "create_chat_with_retries",
|
||||
AsyncMock(return_value=_fake_completion("hello world")))):
|
||||
resp = await client.post("/v1/responses",
|
||||
json={"model": "test-model", "input": "hi", "store": False})
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["object"] == "response"
|
||||
assert body["output_text"] == "hello world"
|
||||
assert body["usage"] == {"input_tokens": 3, "output_tokens": 5, "total_tokens": 8}
|
||||
assert body["id"].startswith("resp_")
|
||||
|
||||
async def test_stream_event_sequence(self, client):
|
||||
with _enter(*_patch_backend(native=False),
|
||||
patch.object(api_responses, "create_chat_with_retries",
|
||||
AsyncMock(return_value=_text_chunks()))):
|
||||
resp = await client.post("/v1/responses",
|
||||
json={"model": "test-model", "input": "hi",
|
||||
"stream": True, "store": False})
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["content-type"].startswith("text/event-stream")
|
||||
events = _sse_events(resp.content.decode())
|
||||
types = [e[0] for e in events]
|
||||
assert types[0] == "response.created"
|
||||
assert "response.output_text.delta" in types
|
||||
assert types[-1] == "response.completed"
|
||||
# concatenated deltas reconstruct the content
|
||||
deltas = "".join(d["delta"] for t, d in events if t == "response.output_text.delta")
|
||||
assert deltas == "Hello"
|
||||
# completed event carries usage
|
||||
completed = [d for t, d in events if t == "response.completed"][0]
|
||||
assert completed["response"]["usage"]["input_tokens"] == 3
|
||||
|
||||
async def test_stream_tool_calls(self, client):
|
||||
with _enter(*_patch_backend(native=False),
|
||||
patch.object(api_responses, "create_chat_with_retries",
|
||||
AsyncMock(return_value=_toolcall_chunks()))):
|
||||
resp = await client.post("/v1/responses",
|
||||
json={"model": "test-model", "input": "lookup hi",
|
||||
"stream": True, "store": False})
|
||||
events = _sse_events(resp.content.decode())
|
||||
types = [e[0] for e in events]
|
||||
assert "response.function_call_arguments.delta" in types
|
||||
assert "response.function_call_arguments.done" in types
|
||||
args = "".join(d["delta"] for t, d in events
|
||||
if t == "response.function_call_arguments.delta")
|
||||
assert args == '{"q":"hi"}'
|
||||
completed = [d for t, d in events if t == "response.completed"][0]
|
||||
fc = [i for i in completed["response"]["output"] if i["type"] == "function_call"][0]
|
||||
assert fc["name"] == "lookup"
|
||||
assert fc["arguments"] == '{"q":"hi"}'
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Native path (external OpenAI backend)
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestNativePath:
|
||||
async def test_nonstream_passthrough_rewrites_id(self, client):
|
||||
oclient = MagicMock()
|
||||
resp_obj = MagicMock()
|
||||
resp_obj.model_dump.return_value = {
|
||||
"id": "resp_openai", "status": "completed",
|
||||
"output": [{"type": "message", "role": "assistant",
|
||||
"content": [{"type": "output_text", "text": "native hi"}]}],
|
||||
"usage": {"input_tokens": 2, "output_tokens": 3, "total_tokens": 5}}
|
||||
oclient.responses.create = AsyncMock(return_value=resp_obj)
|
||||
with (patch.object(api_responses, "choose_endpoint",
|
||||
AsyncMock(return_value=("https://api.openai.com/v1", "gpt"))),
|
||||
patch.object(api_responses, "decrement_usage", AsyncMock()),
|
||||
patch.object(api_responses, "is_ext_openai_endpoint", return_value=True),
|
||||
patch.object(api_responses, "_make_openai_client", return_value=oclient),
|
||||
patch.object(api_responses, "get_llm_cache", return_value=None)):
|
||||
resp = await client.post("/v1/responses",
|
||||
json={"model": "gpt", "input": "hi", "store": False})
|
||||
body = resp.json()
|
||||
assert body["output_text"] == "native hi"
|
||||
assert body["id"].startswith("resp_") and body["id"] != "resp_openai"
|
||||
# native call must not delegate state upstream
|
||||
assert oclient.responses.create.call_args.kwargs["store"] is False
|
||||
|
||||
async def test_stream_passthrough(self, client):
|
||||
oclient = MagicMock()
|
||||
oclient.responses.create = AsyncMock(return_value=_native_event_stream())
|
||||
with (patch.object(api_responses, "choose_endpoint",
|
||||
AsyncMock(return_value=("https://api.openai.com/v1", "gpt"))),
|
||||
patch.object(api_responses, "decrement_usage", AsyncMock()),
|
||||
patch.object(api_responses, "is_ext_openai_endpoint", return_value=True),
|
||||
patch.object(api_responses, "_make_openai_client", return_value=oclient),
|
||||
patch.object(api_responses, "get_llm_cache", return_value=None)):
|
||||
resp = await client.post("/v1/responses",
|
||||
json={"model": "gpt", "input": "hi",
|
||||
"stream": True, "store": False})
|
||||
events = _sse_events(resp.content.decode())
|
||||
# the completed event's response id is rewritten to the router id
|
||||
completed = [d for t, d in events if t == "response.completed"][0]
|
||||
assert completed["response"]["id"].startswith("resp_")
|
||||
assert completed["response"]["id"] != "resp_openai"
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Storage + chaining + retrieve/delete
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestStorageAndChaining:
|
||||
async def test_store_and_retrieve(self, client):
|
||||
with _enter(*_patch_backend(native=False),
|
||||
patch.object(api_responses, "create_chat_with_retries",
|
||||
AsyncMock(return_value=_fake_completion("remembered")))):
|
||||
created = await client.post("/v1/responses",
|
||||
json={"model": "test-model", "input": "hi", "store": True})
|
||||
rid = created.json()["id"]
|
||||
got = await client.get(f"/v1/responses/{rid}")
|
||||
assert got.status_code == 200
|
||||
assert got.json()["output_text"] == "remembered"
|
||||
|
||||
async def test_previous_response_id_rehydrates_history(self, client):
|
||||
# First turn
|
||||
with _enter(*_patch_backend(native=False),
|
||||
patch.object(api_responses, "create_chat_with_retries",
|
||||
AsyncMock(return_value=_fake_completion("turn-one")))):
|
||||
first = await client.post("/v1/responses",
|
||||
json={"model": "test-model", "input": "first?", "store": True})
|
||||
rid = first.json()["id"]
|
||||
|
||||
# Second turn references the first — capture the messages sent to the backend
|
||||
capture = AsyncMock(return_value=_fake_completion("turn-two"))
|
||||
with _enter(*_patch_backend(native=False),
|
||||
patch.object(api_responses, "create_chat_with_retries", capture)):
|
||||
await client.post("/v1/responses",
|
||||
json={"model": "test-model", "input": "second?",
|
||||
"previous_response_id": rid, "store": True})
|
||||
sent_messages = capture.call_args.args[1]["messages"]
|
||||
contents = [m.get("content") for m in sent_messages]
|
||||
assert "first?" in contents # prior user turn replayed
|
||||
assert "turn-one" in contents # prior assistant turn replayed
|
||||
assert "second?" in contents # current turn appended
|
||||
|
||||
async def test_delete(self, client):
|
||||
with _enter(*_patch_backend(native=False),
|
||||
patch.object(api_responses, "create_chat_with_retries",
|
||||
AsyncMock(return_value=_fake_completion("bye")))):
|
||||
created = await client.post("/v1/responses",
|
||||
json={"model": "test-model", "input": "hi", "store": True})
|
||||
rid = created.json()["id"]
|
||||
deleted = await client.delete(f"/v1/responses/{rid}")
|
||||
assert deleted.status_code == 200
|
||||
assert deleted.json()["deleted"] is True
|
||||
assert (await client.get(f"/v1/responses/{rid}")).status_code == 404
|
||||
|
||||
async def test_retrieve_missing_404(self, client):
|
||||
assert (await client.get("/v1/responses/resp_missing")).status_code == 404
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Background mode
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestBackgroundMode:
|
||||
async def test_background_requires_store(self, client):
|
||||
resp = await client.post("/v1/responses",
|
||||
json={"model": "test-model", "input": "hi",
|
||||
"background": True, "store": False})
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_background_lifecycle(self, client):
|
||||
with _enter(*_patch_backend(native=False),
|
||||
patch.object(api_responses, "create_chat_with_retries",
|
||||
AsyncMock(return_value=_fake_completion("bg-done")))):
|
||||
created = await client.post("/v1/responses",
|
||||
json={"model": "test-model", "input": "hi",
|
||||
"background": True, "store": True})
|
||||
assert created.status_code == 200
|
||||
assert created.json()["status"] == "queued"
|
||||
rid = created.json()["id"]
|
||||
# poll until terminal
|
||||
status = None
|
||||
for _ in range(100):
|
||||
await asyncio.sleep(0.01)
|
||||
got = await client.get(f"/v1/responses/{rid}")
|
||||
status = got.json()["status"]
|
||||
if status in ("completed", "failed", "cancelled"):
|
||||
break
|
||||
assert status == "completed"
|
||||
assert got.json()["output_text"] == "bg-done"
|
||||
|
||||
async def test_fail_orphaned_responses(self, client):
|
||||
db = router.db
|
||||
await db.store_response("resp_orphan", previous_response_id=None, model="m",
|
||||
status="in_progress", created_at=0, input_messages=[])
|
||||
n = await db.fail_orphaned_responses()
|
||||
assert n >= 1
|
||||
row = await db.get_response("resp_orphan")
|
||||
assert row["status"] == "failed"
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Cache parity
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class _FakeCache:
|
||||
def __init__(self, response_bytes):
|
||||
self._resp = response_bytes
|
||||
self.calls = []
|
||||
|
||||
async def get_chat(self, route, model, messages):
|
||||
self.calls.append((route, model, messages))
|
||||
return self._resp
|
||||
|
||||
|
||||
class TestCacheParity:
|
||||
async def test_cache_hit_served_as_response(self, client):
|
||||
cached = orjson.dumps(rt.build_response_object(
|
||||
response_id="resp_cached", model="test-model",
|
||||
output_items=rt.chat_message_to_output_items(
|
||||
{"role": "assistant", "content": "from-cache"})))
|
||||
fake = _FakeCache(cached)
|
||||
with (patch.object(api_responses, "get_llm_cache", return_value=fake),
|
||||
patch.object(api_responses, "choose_endpoint",
|
||||
AsyncMock(side_effect=AssertionError("backend must not be reached")))):
|
||||
resp = await client.post("/v1/responses",
|
||||
json={"model": "test-model", "input": "ping",
|
||||
"store": False, "nomyo": {"cache": True}})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["output_text"] == "from-cache"
|
||||
assert fake.calls and fake.calls[0][0] == "openai_responses"
|
||||
|
||||
async def test_cache_hit_served_as_sse(self, client):
|
||||
cached = orjson.dumps(rt.build_response_object(
|
||||
response_id="resp_cached", model="test-model",
|
||||
output_items=rt.chat_message_to_output_items(
|
||||
{"role": "assistant", "content": "from-cache"})))
|
||||
fake = _FakeCache(cached)
|
||||
with (patch.object(api_responses, "get_llm_cache", return_value=fake),
|
||||
patch.object(api_responses, "choose_endpoint",
|
||||
AsyncMock(side_effect=AssertionError("backend must not be reached")))):
|
||||
resp = await client.post("/v1/responses",
|
||||
json={"model": "test-model", "input": "ping",
|
||||
"stream": True, "store": False,
|
||||
"nomyo": {"cache": True}})
|
||||
assert resp.headers["content-type"].startswith("text/event-stream")
|
||||
events = _sse_events(resp.content.decode())
|
||||
deltas = "".join(d["delta"] for t, d in events if t == "response.output_text.delta")
|
||||
assert deltas == "from-cache"
|
||||
|
|
@ -277,3 +277,49 @@ class TestGetTrackingModel:
|
|||
with patch.object(router, "config", cfg):
|
||||
result = router.get_tracking_model(ep, "unsloth/model:Q8_0")
|
||||
assert result == "model"
|
||||
|
||||
|
||||
class TestLlamaSwapClassification:
|
||||
def _cfg(self, *, server=None, swap=None):
|
||||
cfg = MagicMock()
|
||||
cfg.endpoints = []
|
||||
cfg.llama_server_endpoints = server or []
|
||||
cfg.llama_swap_endpoints = swap or []
|
||||
return cfg
|
||||
|
||||
def test_is_llama_swap_only_for_swap_list(self):
|
||||
from backends.normalize import is_llama_swap
|
||||
swap_ep = "http://host:8890/v1"
|
||||
server_ep = "http://host:8889/v1"
|
||||
cfg = self._cfg(server=[server_ep], swap=[swap_ep])
|
||||
with patch.object(router, "config", cfg):
|
||||
assert is_llama_swap(swap_ep) is True
|
||||
assert is_llama_swap(server_ep) is False
|
||||
|
||||
def test_is_llama_server_covers_both(self):
|
||||
from backends.normalize import is_llama_server
|
||||
swap_ep = "http://host:8890/v1"
|
||||
server_ep = "http://host:8889/v1"
|
||||
cfg = self._cfg(server=[server_ep], swap=[swap_ep])
|
||||
with patch.object(router, "config", cfg):
|
||||
assert is_llama_server(swap_ep) is True
|
||||
assert is_llama_server(server_ep) is True
|
||||
assert is_llama_server("http://host:11434") is False
|
||||
|
||||
def test_swap_is_openai_compatible_not_ext(self):
|
||||
swap_ep = "http://host:8890/v1"
|
||||
cfg = self._cfg(swap=[swap_ep])
|
||||
with patch.object(router, "config", cfg):
|
||||
assert router.is_openai_compatible(swap_ep) is True
|
||||
assert router.is_ext_openai_endpoint(swap_ep) is False
|
||||
|
||||
def test_swap_tracking_model_normalized(self):
|
||||
swap_ep = "http://host:8890/v1"
|
||||
cfg = self._cfg(swap=[swap_ep])
|
||||
with patch.object(router, "config", cfg):
|
||||
assert router.get_tracking_model(swap_ep, "unsloth/model:Q8_0") == "model"
|
||||
|
||||
def test_llama_endpoints_dedupes_and_orders(self):
|
||||
from backends.normalize import llama_endpoints
|
||||
cfg = self._cfg(server=["a", "b"], swap=["b", "c"])
|
||||
assert llama_endpoints(cfg) == ["a", "b", "c"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue