feat: add llama-swap as a backend
This commit is contained in:
parent
c8da58430a
commit
aa8baebac5
17 changed files with 544 additions and 52 deletions
|
|
@ -34,6 +34,8 @@ from backends.normalize import (
|
|||
ep2base,
|
||||
is_ext_openai_endpoint,
|
||||
is_openai_compatible,
|
||||
is_llama_server,
|
||||
llama_endpoints,
|
||||
_normalize_llama_model_name,
|
||||
)
|
||||
from backends.probe import fetch
|
||||
|
|
@ -353,7 +355,7 @@ async def openai_chat_completions_proxy(request: Request):
|
|||
resolved_msgs = await _normalize_images_in_messages(params.get("messages", []))
|
||||
send_params = {**params, "messages": resolved_msgs}
|
||||
# Proactive trim: only for small-ctx models we've already seen run out of space
|
||||
_lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model
|
||||
_lookup_model = _normalize_llama_model_name(model) if is_llama_server(endpoint) else model
|
||||
_known_nctx = _endpoint_nctx.get((endpoint, _lookup_model))
|
||||
if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||
_pre_target = int(((_known_nctx - _known_nctx // 4)) / 1.2)
|
||||
|
|
@ -658,9 +660,9 @@ async def openai_models_proxy(request: Request):
|
|||
ollama_tasks = [fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep]
|
||||
# 2. Query external OpenAI endpoints (Groq, OpenAI, etc.) via /models
|
||||
ext_openai_tasks = [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in config.endpoints if is_ext_openai_endpoint(ep)]
|
||||
# 3. Query llama-server endpoints for loaded models via /v1/models
|
||||
# Also query endpoints from llama_server_endpoints that may not be in config.endpoints
|
||||
all_llama_endpoints = set(config.llama_server_endpoints) | set(ep for ep in config.endpoints if ep in config.llama_server_endpoints)
|
||||
# 3. Query llama-server / llama-swap endpoints for advertised models via /v1/models
|
||||
# Also query endpoints that may not be in config.endpoints
|
||||
all_llama_endpoints = llama_endpoints(config)
|
||||
llama_tasks = [
|
||||
fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
|
||||
for ep in all_llama_endpoints
|
||||
|
|
@ -783,10 +785,10 @@ async def rerank_proxy(request: Request):
|
|||
upstream_payload[optional_key] = payload[optional_key]
|
||||
|
||||
# Determine upstream URL:
|
||||
# llama-server exposes /v1/rerank (base already contains /v1 for llama_server_endpoints)
|
||||
# llama-server / llama-swap expose /v1/rerank (base already contains /v1)
|
||||
# External OpenAI endpoints expose /rerank under their /v1 base
|
||||
if endpoint in config.llama_server_endpoints:
|
||||
# llama-server: endpoint may or may not already contain /v1
|
||||
if is_llama_server(endpoint):
|
||||
# llama-server / llama-swap: endpoint may or may not already contain /v1
|
||||
if "/v1" in endpoint:
|
||||
rerank_url = f"{endpoint}/rerank"
|
||||
else:
|
||||
|
|
@ -823,3 +825,82 @@ async def rerank_proxy(request: Request):
|
|||
return JSONResponse(content=data)
|
||||
finally:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
|
||||
async def _resolve_llama_swap_endpoint(model_id: str) -> str | None:
|
||||
"""Pick the llama-swap endpoint that serves ``model_id``.
|
||||
|
||||
Prefers an endpoint that already has the worker running; falls back to any
|
||||
that advertises the model. Returns None if none do.
|
||||
"""
|
||||
config = get_config()
|
||||
swap_eps = config.llama_swap_endpoints
|
||||
if not swap_eps:
|
||||
return None
|
||||
|
||||
advertised = await asyncio.gather(
|
||||
*[fetch.available_models(ep, config.api_keys.get(ep)) for ep in swap_eps]
|
||||
)
|
||||
candidates = [ep for ep, models in zip(swap_eps, advertised) if model_id in models]
|
||||
if not candidates:
|
||||
return None
|
||||
if len(candidates) == 1:
|
||||
return candidates[0]
|
||||
|
||||
loaded = await asyncio.gather(*[fetch.loaded_models(ep) for ep in candidates])
|
||||
for ep, lm in zip(candidates, loaded):
|
||||
if model_id in lm:
|
||||
return ep
|
||||
return candidates[0]
|
||||
|
||||
|
||||
@router.api_route("/upstream/{model_id}/{path:path}", methods=["GET", "POST"])
|
||||
async def llama_swap_upstream(model_id: str, path: str, request: Request):
|
||||
"""Bypass llama-swap and reach a model's underlying llama-server worker directly
|
||||
via llama-swap's ``/upstream/:model_id`` route.
|
||||
|
||||
Lets clients use llama-server features that llama-swap itself does not forward
|
||||
(e.g. token-array prompts), while still letting the router pick the backend that
|
||||
actually hosts the model. ``/upstream`` is a root route, so the ``/v1`` suffix is
|
||||
stripped from the configured endpoint.
|
||||
"""
|
||||
config = get_config()
|
||||
endpoint = await _resolve_llama_swap_endpoint(model_id)
|
||||
if endpoint is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No configured llama-swap endpoint serves model '{model_id}'.",
|
||||
)
|
||||
|
||||
base_url = endpoint.rstrip("/").removesuffix("/v1")
|
||||
url = f"{base_url}/upstream/{model_id}/{path}"
|
||||
if request.url.query:
|
||||
url = f"{url}?{request.url.query}"
|
||||
|
||||
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
|
||||
content_type = request.headers.get("content-type")
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
api_key = config.api_keys.get(endpoint)
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = "Bearer " + api_key
|
||||
|
||||
body = await request.body()
|
||||
client: aiohttp.ClientSession = get_session(endpoint)
|
||||
try:
|
||||
resp = await client.request(request.method, url, data=body or None, headers=headers)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream request to {url} failed: {e}")
|
||||
|
||||
async def _iter():
|
||||
try:
|
||||
async for chunk in resp.content.iter_any():
|
||||
yield chunk
|
||||
finally:
|
||||
resp.release()
|
||||
|
||||
return StreamingResponse(
|
||||
_iter(),
|
||||
status_code=resp.status,
|
||||
media_type=resp.headers.get("Content-Type"),
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue