nomyo-router/api/openai.py

804 lines
37 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""OpenAI-compatible routes (``/v1/embeddings``, ``/v1/chat/completions``,
``/v1/completions``, ``/v1/models``, ``/v1/rerank`` and ``/rerank``).
The chat-completions handler carries the full reactive-trim logic for
``exceed_context_size_error``; both the chat-completions and completions
handlers perform connection-failure rerouting (``_mark_backend_unhealthy``).
The streaming branches assemble cached responses on the fly so caching
works for both streaming and non-streaming clients.
"""
import asyncio
import base64
import math
import aiohttp
import orjson
from fastapi import APIRouter, HTTPException, Request
from starlette.responses import JSONResponse, StreamingResponse
from cache import get_llm_cache, openai_nonstream_to_sse
from config import get_config
from context_window import (
_count_message_tokens,
_trim_messages_for_context,
_calibrated_trim_target,
_endpoint_nctx,
_CTX_TRIM_SMALL_LIMIT,
)
from fingerprint import _conversation_fingerprint
from security import _mask_secrets
from state import token_queue, app_state, default_headers
from backends.health import _is_backend_connection_error, _mark_backend_unhealthy
from backends.normalize import (
dedupe_on_keys,
ep2base,
is_ext_openai_endpoint,
is_openai_compatible,
_normalize_llama_model_name,
)
from backends.probe import fetch
from backends.sessions import _make_openai_client, get_session
from requests.messages import _strip_assistant_prefill, _strip_images_from_messages
from requests.rechunk import rechunk
from routing import choose_endpoint, decrement_usage
# FastAPI router holding every OpenAI-compatible route defined in this module.
router = APIRouter()
@router.post("/v1/embeddings")
async def openai_embedding_proxy(request: Request):
    """
    Proxy an OpenAI API compatible embedding request to Ollama and reply with embeddings.

    Multimodal ``input`` parts are reduced to their text (non-text parts such as
    images are dropped). Non-finite floats in the returned vectors are replaced
    with ``0.0`` so the response stays JSON-serialisable.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, or a missing
            ``model`` / ``input`` field. Backend errors propagate unchanged;
            connection failures are first reported via ``_mark_backend_unhealthy``.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        # orjson parses bytes directly, so invalid UTF-8 is reported as a
        # JSONDecodeError (400) rather than an unhandled UnicodeDecodeError.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    model = payload.get("model")
    doc = payload.get("input")
    # Normalize multimodal input: keep only text parts for embedding models
    if isinstance(doc, list):
        normalized = []
        for item in doc:
            if isinstance(item, dict):
                if item.get("type") == "text":
                    normalized.append(item.get("text", ""))
                # image_url and other non-text parts are skipped
            else:
                normalized.append(item)
        # Unwrap a lone string so ["text"] / [{"type": "text", ...}] are sent as
        # plain text. Only strings: a one-element token array such as [42] must
        # stay a list instead of collapsing into a bare int.
        if len(normalized) == 1 and isinstance(normalized[0], str):
            doc = normalized[0]
        else:
            doc = normalized
    elif isinstance(doc, dict) and doc.get("type") == "text":
        doc = doc.get("text", "")
    if not isinstance(model, str) or not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not doc:
        raise HTTPException(status_code=400, detail="Missing required field 'input'")

    # 2. Endpoint logic — the usage counter taken here is released in `finally`,
    # even if building the client fails.
    endpoint, tracking_model = await choose_endpoint(model)
    try:
        if is_openai_compatible(endpoint):
            api_key = config.api_keys.get(endpoint, "no-key")
        else:
            api_key = "ollama"
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=api_key)
        try:
            response = await oclient.embeddings.create(input=doc, model=model)
        except Exception as e:
            # Same rerouting signal as the chat/completions handlers
            if _is_backend_connection_error(e):
                print(f"[oembed] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
                await _mark_backend_unhealthy(endpoint, model, str(e))
            raise
        result = response.model_dump()
        for item in result.get("data") or []:
            emb = item.get("embedding")
            if isinstance(emb, list):
                # NaN / ±inf are not valid JSON numbers
                item["embedding"] = [
                    0.0 if isinstance(v, float) and not math.isfinite(v) else v
                    for v in emb
                ]
        return JSONResponse(content=result)
    finally:
        await decrement_usage(endpoint, tracking_model)
# Substrings that identify a backend "prompt exceeds the context window" error.
_CTX_ERROR_MARKERS = ("exceed_context_size_error", "exceeds the available context size")


def _is_ctx_exceeded_error(text: str) -> bool:
    """Return True if *text* (a stringified backend error) reports context overflow."""
    return any(marker in text for marker in _CTX_ERROR_MARKERS)


def _as_int(value) -> int:
    """Coerce *value* to ``int``, returning 0 when it is missing or not numeric."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return 0


def _ctx_error_limits(exc: Exception, exc_str: str) -> tuple:
    """Return ``(n_ctx, n_prompt_tokens)`` from a context-overflow error (0 if unknown).

    The structured ``exc.body`` is tried first. It may hold either the whole
    error document or only its inner ``"error"`` object, so both shapes are
    accepted; any other shape (e.g. ``"error"`` being a plain string) is
    ignored instead of raising. When no ``n_ctx`` is found there, both values
    are parsed from the stringified error, with single- or double-quoted keys.
    """
    import re

    body = getattr(exc, "body", None)
    detail = {}
    if isinstance(body, dict):
        inner = body.get("error")
        detail = inner if isinstance(inner, dict) else body
    n_ctx = _as_int(detail.get("n_ctx"))
    n_prompt = _as_int(detail.get("n_prompt_tokens"))
    # Fallback: parse from the string (the SDK may not parse llama-server errors)
    if not n_ctx:
        match = re.search(r"""["']n_ctx["']\s*:\s*(\d+)""", exc_str)
        if match:
            n_ctx = int(match.group(1))
        match = re.search(r"""["']n_prompt_tokens["']\s*:\s*(\d+)""", exc_str)
        if match:
            n_prompt = int(match.group(1))
    return n_ctx, n_prompt


async def _inline_remote_images(msgs: list) -> list:
    """Return *msgs* with remote ``image_url`` parts replaced by base64 ``data:`` URLs.

    Lets Ollama/llama-server handle the images without making outbound HTTP
    requests themselves. A part whose download fails (network error or an HTTP
    error status) is forwarded unchanged.

    NOTE(review): the URLs come straight from the client request, so this is a
    server-side fetch of arbitrary URLs (SSRF exposure) with no size cap.
    Consider an allow-list / private-address block before exposing publicly.
    """
    resolved = []
    for msg in msgs:
        content = msg.get("content") if isinstance(msg, dict) else None
        if not isinstance(content, list):
            resolved.append(msg)
            continue
        new_content = []
        for item in content:
            is_image = isinstance(item, dict) and item.get("type") == "image_url"
            image_url = item.get("image_url") if is_image else None
            url = image_url.get("url", "") if isinstance(image_url, dict) else ""
            if not isinstance(url, str) or not url or url.startswith("data:"):
                new_content.append(item)
                continue
            try:
                http: aiohttp.ClientSession = app_state["session"]
                async with http.get(url) as resp:
                    # Without this an error page (404 HTML, ...) would be
                    # base64-encoded and sent to the model as an "image".
                    resp.raise_for_status()
                    ctype = resp.headers.get("Content-Type", "image/jpeg").split(";")[0].strip()
                    img_bytes = await resp.read()
            except Exception as _ie:
                print(f"[image] Failed to fetch image URL: {_ie}")
                new_content.append(item)
                continue
            b64 = base64.b64encode(img_bytes).decode("utf-8")
            new_content.append({
                "type": "image_url",
                "image_url": {"url": f"data:{ctype};base64,{b64}"},
            })
        resolved.append({**msg, "content": new_content})
    return resolved


def _proactive_ctx_trim(send_params: dict, known_nctx) -> dict:
    """Pre-trim the messages of a request bound for a known small-context model.

    Only applies when *known_nctx* (a previously learned ``n_ctx``) is at most
    ``_CTX_TRIM_SMALL_LIMIT`` and the estimated prompt exceeds the target:
    75% of n_ctx divided by 1.2 (presumably a margin for estimate error).
    Returns *send_params* unchanged otherwise.
    """
    if not known_nctx or known_nctx > _CTX_TRIM_SMALL_LIMIT:
        return send_params
    target = int((known_nctx - known_nctx // 4) / 1.2)
    msgs = send_params.get("messages", [])
    estimate = _count_message_tokens(msgs)
    if estimate <= target:
        return send_params
    trimmed = _trim_messages_for_context(msgs, known_nctx, target_tokens=target)
    dropped = len(msgs) - len(trimmed)
    print(f"[ctx-pre] n_ctx={known_nctx} est={estimate} target={target} dropped={dropped}", flush=True)
    return {**send_params, "messages": trimmed}


async def _retry_chat_after_ctx_trim(oclient, send_params: dict, n_ctx_limit: int, actual_tokens: int, nctx_key: tuple):
    """Retry a chat call that overflowed the backend context window.

    Remembers *n_ctx_limit* under *nctx_key* for small-context models, applies
    the sliding-window trim from ``_trim_messages_for_context`` and retries. If
    that still overflows, retries once more without ``tools``/``tool_choice``.
    Any other failure is re-raised.
    """
    if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
        _endpoint_nctx[nctx_key] = n_ctx_limit
    msgs_to_trim = send_params.get("messages", [])
    try:
        cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
        trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
    except Exception as _helper_exc:
        print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True)
        raise
    dropped = len(msgs_to_trim) - len(trimmed_messages)
    print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True)
    try:
        result = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages})
    except Exception as e2:
        if not _is_ctx_exceeded_error(str(e2)):
            raise
        # Still too large — tool definitions likely consume too many tokens; strip them too
        print("[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True)
        params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")}
        result = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages})
        print("[ctx-trim] retry-2 ok", flush=True)
        return result
    print("[ctx-trim] retry-1 ok", flush=True)
    return result


async def _create_chat_with_recovery(oclient, send_params: dict, endpoint: str, model: str, nctx_key: tuple):
    """Open the upstream chat completion, recovering from known backend refusals.

    Retried: unsupported ``tools`` (without them), context overflow (trim and
    retry), unsupported image input (text-only). A connection failure marks
    ``(endpoint, model)`` unhealthy and re-raises so the next request reroutes.
    Anything else is re-raised unchanged. The caller releases the endpoint
    usage counter on failure.
    """
    try:
        return await oclient.chat.completions.create(**send_params)
    except Exception as e:
        _e_str = str(e)
        _is_ctx_err = _is_ctx_exceeded_error(_e_str)
        print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True)
        if "does not support tools" in _e_str:
            # Model doesn't support tools — retry without them
            print("[ochat] retry: no tools", flush=True)
            params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
            return await oclient.chat.completions.create(**params_without_tools)
        if _is_ctx_err:
            # Backend context limit hit — apply sliding-window trim (context-shift at message level)
            n_ctx_limit, actual_tokens = _ctx_error_limits(e, _e_str)
            print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True)
            if not n_ctx_limit:
                raise
            return await _retry_chat_after_ctx_trim(oclient, send_params, n_ctx_limit, actual_tokens, nctx_key)
        if _is_backend_connection_error(e):
            # Upstream connection failed (e.g. llama-server in router mode
            # whose delegated worker died). Mark (endpoint, model) so the
            # next request reroutes; the client will retry this one.
            print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
            await _mark_backend_unhealthy(endpoint, model, _e_str)
            raise
        if "image input is not supported" in _e_str:
            # Model doesn't support images — strip and retry
            print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages")
            text_only = _strip_images_from_messages(send_params.get("messages", []))
            return await oclient.chat.completions.create(**{**send_params, "messages": text_only})
        raise


def _ochat_json_bytes(obj) -> bytes:
    """Serialise an SDK response object (or plain JSON-able data) to UTF-8 JSON bytes."""
    if hasattr(obj, "model_dump_json"):
        return obj.model_dump_json().encode("utf-8")
    return orjson.dumps(obj)


@router.post("/v1/chat/completions")
async def openai_chat_completions_proxy(request: Request):
    """
    Proxy an OpenAI API compatible chat completions request and stream back the reply.

    Flow:
    1. Validate the body and reject SVG images.
    2. Serve from the LLM cache when ``nomyo.cache`` is set.
    3. Choose an endpoint and inline remote images for non-external backends.
    4. Pre-trim for known small-context models.
    5. Open the upstream call with recovery (see ``_create_chat_with_recovery``).

    Returns a StreamingResponse: SSE (``text/event-stream``) when ``stream`` is
    truthy, otherwise a single JSON document.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, a missing
            ``model``/``messages`` field, or an SVG image. Unrecovered backend
            errors are re-raised after the endpoint usage counter is released.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        # orjson parses bytes directly, so invalid UTF-8 is reported as a
        # JSONDecodeError (400) rather than an unhandled UnicodeDecodeError.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    model = payload.get("model")
    messages = payload.get("messages")
    stream = payload.get("stream")
    nomyo_opts = payload.get("nomyo")
    # Tolerate "nomyo": null / non-object instead of failing with a 500
    _cache_enabled = isinstance(nomyo_opts, dict) and bool(nomyo_opts.get("cache", False))
    if not isinstance(model, str) or not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not isinstance(messages, list):
        raise HTTPException(
            status_code=400, detail="Missing required field 'messages' (must be a list)"
        )
    if ":latest" in model:
        model = model.split(":latest")[0]
    messages = _strip_assistant_prefill(messages)
    stream_options = payload.get("stream_options")
    if stream and not stream_options:
        # Request a final usage chunk so streamed token counts are recorded.
        # Not defaulted for non-streaming calls: OpenAI rejects stream_options
        # unless stream is enabled.
        stream_options = {"include_usage": True}
    params = {
        "messages": messages,
        "model": model,
    }
    optional_params = {
        "tools": payload.get("tools"),
        "response_format": payload.get("response_format"),
        "stream_options": stream_options,
        "max_completion_tokens": payload.get("max_completion_tokens"),
        "max_tokens": payload.get("max_tokens"),
        "temperature": payload.get("temperature"),
        "top_p": payload.get("top_p"),
        "seed": payload.get("seed"),
        "presence_penalty": payload.get("presence_penalty"),
        "frequency_penalty": payload.get("frequency_penalty"),
        "stop": payload.get("stop"),
        "stream": stream,
        "logprobs": payload.get("logprobs"),
        "top_logprobs": payload.get("top_logprobs"),
    }
    params.update({k: v for k, v in optional_params.items() if v is not None})

    # Reject unsupported image formats (SVG) before doing any work
    for _msg in messages:
        _content = _msg.get("content") if isinstance(_msg, dict) else None
        if not isinstance(_content, list):
            continue
        for _item in _content:
            if not isinstance(_item, dict) or _item.get("type") != "image_url":
                continue
            _image_url = _item.get("image_url")
            _url = _image_url.get("url", "") if isinstance(_image_url, dict) else ""
            # Media types in data URLs are case-insensitive
            _url_lc = _url.lower() if isinstance(_url, str) else ""
            if _url_lc.startswith("data:image/svg") or _url_lc.endswith(".svg"):
                raise HTTPException(
                    status_code=400,
                    detail="SVG images are not supported. Please convert the image to PNG or JPEG before sending.",
                )

    # Cache lookup — before endpoint selection
    _cache = get_llm_cache()
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_chat("openai_chat", model, messages)
        if _cached is not None:
            if stream:
                _sse = openai_nonstream_to_sse(_cached, model)

                async def _serve_cached_ochat_stream():
                    yield _sse

                return StreamingResponse(_serve_cached_ochat_stream(), media_type="text/event-stream")

            async def _serve_cached_ochat_json():
                yield _cached

            return StreamingResponse(_serve_cached_ochat_json(), media_type="application/json")

    # 2. Endpoint logic
    _affinity_key = _conversation_fingerprint(model, messages, None)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)

    # 3. Open the upstream call in handler scope — try/except inside async
    # generators is unreliable with Starlette's streaming machinery, so errors
    # are resolved here before the generator starts. Any failure releases the
    # usage counter taken by choose_endpoint; on success the generator's
    # `finally` releases it instead.
    try:
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
        send_params = params
        if not is_ext_openai_endpoint(endpoint):
            resolved_msgs = await _inline_remote_images(params.get("messages", []))
            send_params = {**params, "messages": resolved_msgs}
        # Use the same key for reading and writing the learned n_ctx. For
        # llama-server endpoints the model name is normalised first, so an
        # entry written under the raw name would never be found here.
        _lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model
        _nctx_key = (endpoint, _lookup_model)
        # Proactive trim: only for small-ctx models we've already seen run out of space
        send_params = _proactive_ctx_trim(send_params, _endpoint_nctx.get(_nctx_key))
        async_gen = await _create_chat_with_recovery(oclient, send_params, endpoint, model, _nctx_key)
    except BaseException:
        await decrement_usage(endpoint, tracking_model)
        raise

    # 4. Async generator — only streams the already-established async_gen
    async def stream_ochat_response():
        try:
            if stream:
                content_parts: list[str] = []
                usage_snapshot: dict = {}
                # finish_reason == "length" only hints at an exhausted context
                # window when the caller set no output-token budget of its own.
                _req_max_tok = send_params.get("max_tokens") or send_params.get("max_completion_tokens")
                async for chunk in async_gen:
                    frame = b"data: " + _ochat_json_bytes(chunk) + b"\n\n"
                    choice = chunk.choices[0] if chunk.choices else None
                    if choice is not None:
                        delta = choice.delta
                        content = getattr(delta, "content", None)
                        has_reasoning = (
                            getattr(delta, "reasoning_content", None) is not None
                            or getattr(delta, "reasoning", None) is not None
                        )
                        has_tool_calls = getattr(delta, "tool_calls", None) is not None
                        # Also forward finish_reason-only chunks (empty delta) so
                        # clients learn why generation stopped ("stop", "tool_calls", ...).
                        if content is not None or has_reasoning or has_tool_calls or choice.finish_reason is not None:
                            yield frame
                        if content:
                            content_parts.append(content)
                    elif chunk.usage is not None:
                        # Forward the usage-only final chunk (e.g. from llama-server)
                        yield frame
                    prompt_tok = 0
                    comp_tok = 0
                    if chunk.usage is not None:
                        prompt_tok = chunk.usage.prompt_tokens or 0
                        comp_tok = chunk.usage.completion_tokens or 0
                        usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok}
                    else:
                        llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
                        if llama_usage:
                            prompt_tok, comp_tok = llama_usage
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                    # Detect context exhaustion mid-generation for small-ctx models.
                    if choice is not None and choice.finish_reason == "length" and not _req_max_tok:
                        _inferred_nctx = (prompt_tok + comp_tok) or 0
                        if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
                            _endpoint_nctx[_nctx_key] = _inferred_nctx
                            print(f"[ctx-cache] finish_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)
                # Cache assembled streaming response — before [DONE] so it always runs
                if _cache is not None and _cache_enabled and content_parts:
                    assembled = orjson.dumps({
                        "model": model,
                        "choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(content_parts)}, "finish_reason": "stop"}],
                        **({"usage": usage_snapshot} if usage_snapshot else {}),
                    }) + b"\n"
                    try:
                        await _cache.set_chat("openai_chat", model, messages, assembled)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_chat streaming) failed: {_ce}")
                yield b"data: [DONE]\n\n"
            else:
                prompt_tok = 0
                comp_tok = 0
                if async_gen.usage is not None:
                    prompt_tok = async_gen.usage.prompt_tokens or 0
                    comp_tok = async_gen.usage.completion_tokens or 0
                else:
                    llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
                    if llama_usage:
                        prompt_tok, comp_tok = llama_usage
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                cache_bytes = _ochat_json_bytes(async_gen) + b"\n"
                yield cache_bytes
                # Cache non-streaming response
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_chat("openai_chat", model, messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_chat non-streaming) failed: {_ce}")
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 5. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_ochat_response(),
        media_type="text/event-stream" if stream else "application/json",
    )
def _ocompl_json_bytes(obj) -> bytes:
    """Serialise an SDK response object (or plain JSON-able data) to UTF-8 JSON bytes."""
    if hasattr(obj, "model_dump_json"):
        return obj.model_dump_json().encode("utf-8")
    return orjson.dumps(obj)


@router.post("/v1/completions")
async def openai_completions_proxy(request: Request):
    """
    Proxy an OpenAI API compatible (text) completions request and reply with a streaming response.

    The prompt is cached as a single-turn ``[{"role": "user", "content": prompt}]``
    conversation under the ``"openai_completions"`` namespace. Connection
    failures mark ``(endpoint, model)`` unhealthy before re-raising.

    Returns a StreamingResponse: SSE (``text/event-stream``) when ``stream`` is
    truthy, otherwise a single JSON document.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, or a missing
            ``model``/``prompt`` field. Backend errors are re-raised after the
            endpoint usage counter is released.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        # orjson parses bytes directly, so invalid UTF-8 is reported as a
        # JSONDecodeError (400) rather than an unhandled UnicodeDecodeError.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    model = payload.get("model")
    prompt = payload.get("prompt")
    stream = payload.get("stream")
    nomyo_opts = payload.get("nomyo")
    # Tolerate "nomyo": null / non-object instead of failing with a 500
    _cache_enabled = isinstance(nomyo_opts, dict) and bool(nomyo_opts.get("cache", False))
    if not isinstance(model, str) or not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not prompt:
        raise HTTPException(status_code=400, detail="Missing required field 'prompt'")
    if ":latest" in model:
        model = model.split(":latest")[0]
    stream_options = payload.get("stream_options")
    if stream and not stream_options:
        # Request a final usage chunk so streamed token counts are recorded.
        # Not defaulted for non-streaming calls: OpenAI rejects stream_options
        # unless stream is enabled.
        stream_options = {"include_usage": True}
    # The SDK's completions.create() has no max_completion_tokens parameter
    # (passing it raises TypeError), so map it onto the legacy max_tokens.
    max_tokens = payload.get("max_tokens")
    if max_tokens is None:
        max_tokens = payload.get("max_completion_tokens")
    params = {
        "prompt": prompt,
        "model": model,
    }
    optional_params = {
        "frequency_penalty": payload.get("frequency_penalty"),
        "presence_penalty": payload.get("presence_penalty"),
        "seed": payload.get("seed"),
        "stop": payload.get("stop"),
        "stream": stream,
        "stream_options": stream_options,
        "temperature": payload.get("temperature"),
        "top_p": payload.get("top_p"),
        "max_tokens": max_tokens,
        "suffix": payload.get("suffix"),
    }
    params.update({k: v for k, v in optional_params.items() if v is not None})

    # Cache lookup — completions prompt mapped to a single-turn messages list
    _cache = get_llm_cache()
    _compl_messages = [{"role": "user", "content": prompt}]
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_chat("openai_completions", model, _compl_messages)
        if _cached is not None:
            if stream:
                _sse = openai_nonstream_to_sse(_cached, model)

                async def _serve_cached_ocompl_stream():
                    yield _sse

                return StreamingResponse(_serve_cached_ocompl_stream(), media_type="text/event-stream")

            async def _serve_cached_ocompl_json():
                yield _cached

            return StreamingResponse(_serve_cached_ocompl_json(), media_type="application/json")

    # 2. Endpoint logic
    _affinity_key = _conversation_fingerprint(model, None, prompt)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)

    # 3. Make the API call in handler scope (try/except inside async generators
    # is unreliable). Any failure here releases the usage counter taken by
    # choose_endpoint; on success the generator's `finally` releases it.
    try:
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
        try:
            async_gen = await oclient.completions.create(**params)
        except Exception as e:
            if _is_backend_connection_error(e):
                print(f"[ocompl] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
                await _mark_backend_unhealthy(endpoint, model, str(e))
            raise
    except BaseException:
        await decrement_usage(endpoint, tracking_model)
        raise

    # 4. Async generator that streams completions data and decrements the counter
    async def stream_ocompletions_response(model=model):
        try:
            if stream:
                text_parts: list[str] = []
                usage_snapshot: dict = {}
                async for chunk in async_gen:
                    frame = b"data: " + _ocompl_json_bytes(chunk) + b"\n\n"
                    if chunk.choices:
                        choice = chunk.choices[0]
                        text = getattr(choice, "text", None)
                        has_reasoning = (
                            getattr(choice, "reasoning_content", None) is not None
                            or getattr(choice, "reasoning", None) is not None
                        )
                        if text is not None or has_reasoning or choice.finish_reason is not None:
                            yield frame
                        if text:
                            text_parts.append(text)
                    elif chunk.usage is not None:
                        # Forward the usage-only final chunk (e.g. from llama-server)
                        yield frame
                    prompt_tok = 0
                    comp_tok = 0
                    if chunk.usage is not None:
                        prompt_tok = chunk.usage.prompt_tokens or 0
                        comp_tok = chunk.usage.completion_tokens or 0
                        usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok}
                    else:
                        llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
                        if llama_usage:
                            prompt_tok, comp_tok = llama_usage
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                # Cache assembled streaming response — before [DONE] so it always runs
                if _cache is not None and _cache_enabled and text_parts:
                    full_text = "".join(text_parts)
                    assembled = orjson.dumps({
                        "model": model,
                        "choices": [{
                            "index": 0,
                            # "text" is what completions clients read from a cached
                            # non-stream reply; "message" is kept as well, presumably
                            # for openai_nonstream_to_sse (unverified).
                            "text": full_text,
                            "message": {"role": "assistant", "content": full_text},
                            "finish_reason": "stop",
                        }],
                        **({"usage": usage_snapshot} if usage_snapshot else {}),
                    }) + b"\n"
                    try:
                        await _cache.set_chat("openai_completions", model, _compl_messages, assembled)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_completions streaming) failed: {_ce}")
                # Final DONE event
                yield b"data: [DONE]\n\n"
            else:
                prompt_tok = 0
                comp_tok = 0
                if async_gen.usage is not None:
                    prompt_tok = async_gen.usage.prompt_tokens or 0
                    comp_tok = async_gen.usage.completion_tokens or 0
                else:
                    llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
                    if llama_usage:
                        prompt_tok, comp_tok = llama_usage
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                cache_bytes = _ocompl_json_bytes(async_gen) + b"\n"
                yield cache_bytes
                # Cache non-streaming response
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_chat("openai_completions", model, _compl_messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_completions non-streaming) failed: {_ce}")
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 5. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_ocompletions_response(),
        media_type="text/event-stream" if stream else "application/json",
    )
@router.get("/v1/models")
async def openai_models_proxy(request: Request):
    """
    Proxy an OpenAI API models request to all configured backends and reply with a unique list of models.

    For Ollama endpoints (no "/v1" in the URL): queries /api/tags (all models).
    For external OpenAI endpoints: queries /models.
    For llama-server endpoints: queries /models and returns every listed model
    (no filtering on load status).

    All endpoints are queried concurrently. Each entry gets both ``id`` and
    ``name`` (``name`` mirrors ``id`` when the backend supplied one, otherwise
    ``id`` is taken from ``name``), and the combined list is de-duplicated on
    ``name``.
    """
    config = get_config()
    ollama_eps = [ep for ep in config.endpoints if "/v1" not in ep]
    ext_openai_eps = [ep for ep in config.endpoints if is_ext_openai_endpoint(ep)]
    # Every llama-server endpoint is queried, whether or not it is also listed
    # in config.endpoints. dict.fromkeys de-duplicates while keeping config
    # order, so the merged list is built in a deterministic order.
    llama_eps = list(dict.fromkeys(config.llama_server_endpoints))

    tasks = [
        fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8)
        for ep in ollama_eps
    ]
    tasks += [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
        for ep in ext_openai_eps + llama_eps
    ]
    # One gather for all groups; results keep task order (Ollama, external, llama-server).
    model_lists = await asyncio.gather(*tasks)

    entries = []
    for modellist in model_lists:
        for raw in modellist or []:
            if not isinstance(raw, dict):
                continue
            # Work on a copy so lists returned by fetch.endpoint_details are not
            # mutated in place (they may be shared or cached — unverified).
            entry = dict(raw)
            if "id" in entry:
                entry["name"] = entry["id"]
            else:
                # Relabel Ollama models with OpenAI Model.id from Model.name
                entry["id"] = entry.get("name", "")
            entries.append(entry)

    # Return a deduplicated list of unique models for inference
    return JSONResponse(
        content={"data": dedupe_on_keys(entries, ["name"])},
        status_code=200,
    )
@router.post("/v1/rerank")
@router.post("/rerank")
async def rerank_proxy(request: Request):
    """
    Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint.

    Compatible with the Jina/Cohere rerank API convention used by llama-server,
    vLLM, and services such as Cohere and Jina AI.
    Ollama does not natively support reranking; requests routed to a plain Ollama
    endpoint will receive a 501 Not Implemented response.

    Request body:
        model (str, required) – reranker model name
        query (str, required) – search query
        documents (list[str], required) – candidate documents to rank
        top_n (int, optional) – limit returned results (default: all)
        return_documents (bool, optional) – include document text in results
        max_tokens_per_doc (int, optional) – truncation limit per document

    Response (Jina/Cohere-compatible):
        {
            "id": "...",
            "model": "...",
            "usage": {"prompt_tokens": N, "total_tokens": N},
            "results": [{"index": 0, "relevance_score": 0.95}, ...]
        }

    Raises:
        HTTPException: 400 for an invalid request, 404 when no endpoint serves
            the model, 501 for plain Ollama endpoints, the upstream status for
            upstream errors, and 502 when the upstream reply is not valid JSON.
    """
    from urllib.parse import urlsplit

    config = get_config()
    try:
        # orjson parses bytes directly, so invalid UTF-8 is reported as a
        # JSONDecodeError (400) rather than an unhandled UnicodeDecodeError.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    model = payload.get("model")
    query = payload.get("query")
    documents = payload.get("documents")
    if not isinstance(model, str) or not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not query:
        raise HTTPException(status_code=400, detail="Missing required field 'query'")
    if not isinstance(documents, list) or not documents:
        raise HTTPException(status_code=400, detail="Missing or empty required field 'documents' (must be a non-empty list)")

    # Determine which endpoint serves this model
    try:
        endpoint, tracking_model = await choose_endpoint(model)
    except RuntimeError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    # From here on, the usage counter taken by choose_endpoint is released in `finally`
    try:
        # Ollama endpoints have no native rerank support
        if not is_openai_compatible(endpoint):
            raise HTTPException(
                status_code=501,
                detail=(
                    f"Endpoint '{endpoint}' is a plain Ollama instance which does not support "
                    "reranking. Use a llama-server or OpenAI-compatible endpoint with a "
                    "dedicated reranker model."
                ),
            )
        if ":latest" in model:
            model = model.split(":latest")[0]

        # Build upstream rerank request body – forward only recognised fields
        upstream_payload: dict = {"model": model, "query": query, "documents": documents}
        for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
            if optional_key in payload:
                upstream_payload[optional_key] = payload[optional_key]

        # Determine upstream URL:
        #   llama-server exposes /v1/rerank (the configured base may already end in /v1)
        #   External OpenAI endpoints expose /rerank under their /v1 base
        if endpoint in config.llama_server_endpoints:
            base = endpoint.rstrip("/")
            # Inspect only the URL path: a substring test on the whole URL would
            # also match host names such as "http://v1host:8080".
            if "/v1" in urlsplit(base).path:
                rerank_url = f"{base}/rerank"
            else:
                rerank_url = f"{base}/v1/rerank"
        else:
            # External OpenAI-compatible: ep2base gives us the /v1 base
            rerank_url = f"{ep2base(endpoint)}/rerank"

        api_key = config.api_keys.get(endpoint, "no-key")
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }
        client: aiohttp.ClientSession = get_session(endpoint)
        try:
            async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp:
                status = resp.status
                response_bytes = await resp.read()
        except Exception as e:
            # Same rerouting signal as the chat/completions handlers
            if _is_backend_connection_error(e):
                print(f"[rerank] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
                await _mark_backend_unhealthy(endpoint, model, str(e))
            raise

        if status >= 400:
            raise HTTPException(
                status_code=status,
                detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")),
            )
        try:
            data = orjson.loads(response_bytes)
        except orjson.JSONDecodeError as e:
            raise HTTPException(status_code=502, detail="Upstream rerank response is not valid JSON") from e

        # Record token usage if the upstream returned a usage object.
        # For reranking there are no completion tokens; only prompt tokens are recorded.
        usage = data.get("usage") if isinstance(data, dict) else None
        if isinstance(usage, dict):
            prompt_tok = usage.get("prompt_tokens") or 0
            total_tok = usage.get("total_tokens") or 0
            if prompt_tok or total_tok:
                await token_queue.put((endpoint, tracking_model, prompt_tok, 0))
        return JSONResponse(content=data)
    finally:
        await decrement_usage(endpoint, tracking_model)