804 lines
37 KiB
Python
804 lines
37 KiB
Python
"""OpenAI-compatible routes (``/v1/embeddings``, ``/v1/chat/completions``,
``/v1/completions``, ``/v1/models``, ``/v1/rerank`` and ``/rerank``).

The chat-completions and completions handlers carry the full reactive-trim
logic for ``exceed_context_size_error`` plus connection-failure rerouting
(``_mark_backend_unhealthy``). The streaming branches assemble cached
responses on the fly so caching works for both streaming and non-streaming
clients.
"""
|
||
import asyncio
|
||
import base64
|
||
import math
|
||
|
||
import aiohttp
|
||
import orjson
|
||
from fastapi import APIRouter, HTTPException, Request
|
||
from starlette.responses import JSONResponse, StreamingResponse
|
||
|
||
from cache import get_llm_cache, openai_nonstream_to_sse
|
||
from config import get_config
|
||
from context_window import (
|
||
_count_message_tokens,
|
||
_trim_messages_for_context,
|
||
_calibrated_trim_target,
|
||
_endpoint_nctx,
|
||
_CTX_TRIM_SMALL_LIMIT,
|
||
)
|
||
from fingerprint import _conversation_fingerprint
|
||
from security import _mask_secrets
|
||
from state import token_queue, app_state, default_headers
|
||
from backends.health import _is_backend_connection_error, _mark_backend_unhealthy
|
||
from backends.normalize import (
|
||
dedupe_on_keys,
|
||
ep2base,
|
||
is_ext_openai_endpoint,
|
||
is_openai_compatible,
|
||
_normalize_llama_model_name,
|
||
)
|
||
from backends.probe import fetch
|
||
from backends.sessions import _make_openai_client, get_session
|
||
from requests.messages import _strip_assistant_prefill, _strip_images_from_messages
|
||
from requests.rechunk import rechunk
|
||
from routing import choose_endpoint, decrement_usage
|
||
|
||
|
||
# Router instance on which every route handler in this module is registered.
router = APIRouter()
|
||
|
||
|
||
@router.post("/v1/embeddings")
async def openai_embedding_proxy(request: Request):
    """
    Proxy an OpenAI API compatible embedding request and reply with embeddings.

    The JSON body must be an object carrying ``model`` and ``input``.
    Multimodal ``input`` parts are reduced to their text: ``{"type": "text"}``
    parts contribute their ``text`` value, every other dict part is skipped.

    Raises:
        HTTPException: 400 when the body is not valid JSON, is not a JSON
            object, or lacks ``model`` / ``input``.
    """
    config = get_config()

    # 1. Parse and validate request
    try:
        # orjson decodes the raw bytes itself and reports invalid UTF-8 as a
        # JSONDecodeError, so a malformed body becomes a 400 instead of a 500.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    doc = payload.get("input")

    # Normalize multimodal input: extract only text parts for embedding models
    if isinstance(doc, list):
        normalized = []
        for item in doc:
            if not isinstance(item, dict):
                normalized.append(item)
            elif item.get("type") == "text":
                normalized.append(item.get("text", ""))
            # image_url and other non-text parts are skipped
        doc = normalized[0] if len(normalized) == 1 else normalized
    elif isinstance(doc, dict) and doc.get("type") == "text":
        doc = doc.get("text", "")

    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not doc:
        raise HTTPException(
            status_code=400, detail="Missing required field 'input'"
        )

    # 2. Endpoint logic
    endpoint, tracking_model = await choose_endpoint(model)
    try:
        # Everything after choose_endpoint() runs under the finally below so
        # the usage counter is released even if client construction fails.
        if is_openai_compatible(endpoint):
            api_key = config.api_keys.get(endpoint, "no-key")
        else:
            api_key = "ollama"
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=api_key)

        response = await oclient.embeddings.create(input=doc, model=model)
        result = response.model_dump()
        for item in result.get("data", []):
            emb = item.get("embedding")
            # Only scrub real vectors; a string payload must not be split
            # into characters.
            if emb and isinstance(emb, list):
                # NaN / Inf cannot be serialised as JSON numbers; replace with 0.0.
                item["embedding"] = [
                    0.0 if isinstance(v, float) and not math.isfinite(v) else v
                    for v in emb
                ]
        return JSONResponse(content=result)
    finally:
        await decrement_usage(endpoint, tracking_model)
|
||
|
||
|
||
@router.post("/v1/chat/completions")
async def openai_chat_completions_proxy(request: Request):
    """
    Proxy an OpenAI API compatible chat completions request to the chosen backend.

    Flow:
      1. Parse/validate the JSON body and build the upstream parameters.
      2. Reject SVG images, then serve from the LLM cache when the request
         opted in via ``{"nomyo": {"cache": true}}``.
      3. Pick an endpoint, inline remote images as data URLs (non-external
         endpoints only), proactively trim for known small-context models,
         and issue the upstream call. Recoverable upstream errors (no tool
         support, context overflow, no image support) are retried here, in
         handler scope, before any response bytes are produced.
      4. Return a ``StreamingResponse`` that relays the upstream result,
         records token usage and stores the response in the cache.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, missing
            ``model`` / ``messages``, or SVG image input.
        Exception: upstream errors that could not be recovered are re-raised
            after the endpoint usage counter has been released.
    """
    config = get_config()

    def _dump_json(obj) -> str:
        """Serialise an SDK object (or plain data) to a JSON string."""
        if hasattr(obj, "model_dump_json"):
            return obj.model_dump_json()
        # orjson.dumps returns bytes; decode so callers can always treat the
        # result as str.
        return orjson.dumps(obj).decode("utf-8")

    # 1. Parse and validate request
    try:
        # orjson decodes the raw bytes itself and reports invalid UTF-8 as a
        # JSONDecodeError, so a malformed body becomes a 400 instead of a 500.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    messages = payload.get("messages")
    stream = payload.get("stream")
    _nomyo = payload.get("nomyo")
    _cache_enabled = _nomyo.get("cache", False) if isinstance(_nomyo, dict) else False

    if not model or not isinstance(model, str):
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not isinstance(messages, list):
        raise HTTPException(
            status_code=400, detail="Missing required field 'messages' (must be a list)"
        )

    if ":latest" in model:
        model = model.split(":latest")[0]

    messages = _strip_assistant_prefill(messages)
    params = {
        "messages": messages,
        "model": model,
    }

    # Optional parameters are forwarded only when the client supplied them.
    for _key in (
        "tools",
        "response_format",
        "max_completion_tokens",
        "max_tokens",
        "temperature",
        "top_p",
        "seed",
        "presence_penalty",
        "frequency_penalty",
        "stop",
        "stream",
        "logprobs",
        "top_logprobs",
    ):
        _value = payload.get(_key)
        if _value is not None:
            params[_key] = _value

    # stream_options is only meaningful for streaming requests: default to
    # include_usage when streaming, otherwise forward it only if the client
    # explicitly sent a non-empty value.
    stream_options = payload.get("stream_options")
    if stream:
        params["stream_options"] = stream_options or {"include_usage": True}
    elif stream_options:
        params["stream_options"] = stream_options

    # Reject unsupported image formats (SVG) before doing any work
    for _msg in messages:
        _content = _msg.get("content") if isinstance(_msg, dict) else None
        if not isinstance(_content, list):
            continue
        for _item in _content:
            if not isinstance(_item, dict) or _item.get("type") != "image_url":
                continue
            _image = _item.get("image_url")
            _url = _image.get("url", "") if isinstance(_image, dict) else ""
            if not isinstance(_url, str):
                continue
            _url_lc = _url.lower()
            if _url_lc.startswith("data:image/svg") or _url_lc.endswith(".svg"):
                raise HTTPException(
                    status_code=400,
                    detail="SVG images are not supported. Please convert the image to PNG or JPEG before sending.",
                )

    # Cache lookup — before endpoint selection
    _cache = get_llm_cache()
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_chat("openai_chat", model, messages)
        if _cached is not None:
            if stream:
                _sse = openai_nonstream_to_sse(_cached, model)

                async def _serve_cached_ochat_stream():
                    yield _sse

                return StreamingResponse(_serve_cached_ochat_stream(), media_type="text/event-stream")

            async def _serve_cached_ochat_json():
                yield _cached

            return StreamingResponse(_serve_cached_ochat_json(), media_type="application/json")

    async def _normalize_images_in_messages(msgs: list) -> list:
        """Fetch remote image URLs and convert them to base64 data URLs so
        Ollama/llama-server can handle them without making outbound HTTP requests."""
        # NOTE(review): the URLs come straight from the client request and are
        # fetched by this process — consider restricting allowed targets.
        resolved = []
        for msg in msgs:
            content = msg.get("content") if isinstance(msg, dict) else None
            if not isinstance(content, list):
                resolved.append(msg)
                continue
            new_content = []
            for item in content:
                url = ""
                if isinstance(item, dict) and item.get("type") == "image_url":
                    image = item.get("image_url")
                    url = image.get("url", "") if isinstance(image, dict) else ""
                if not isinstance(url, str) or not url or url.startswith("data:"):
                    # Not a remote image: pass the part through untouched.
                    new_content.append(item)
                    continue
                try:
                    http: aiohttp.ClientSession = app_state["session"]
                    async with http.get(url) as resp:
                        ctype = resp.headers.get("Content-Type", "image/jpeg").split(";")[0].strip()
                        img_bytes = await resp.read()
                    b64 = base64.b64encode(img_bytes).decode("utf-8")
                    new_content.append({
                        "type": "image_url",
                        "image_url": {"url": f"data:{ctype};base64,{b64}"}
                    })
                except Exception as _ie:
                    # Best effort: keep the original URL part if the fetch fails.
                    print(f"[image] Failed to fetch image URL: {_ie}")
                    new_content.append(item)
            resolved.append({**msg, "content": new_content})
        return resolved

    # 2. Endpoint logic
    _affinity_key = _conversation_fingerprint(model, messages, None)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)

    # 3. API call — done in handler scope because try/except inside async
    # generators is unreliable with Starlette's streaming machinery, so errors
    # are resolved here before the generator starts.
    #
    # From here until the streaming generator takes over, ANY failure
    # (including cancellation) must release the usage counter taken by
    # choose_endpoint(); the single outer handler below guarantees that.
    try:
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))

        send_params = params
        if not is_ext_openai_endpoint(endpoint):
            resolved_msgs = await _normalize_images_in_messages(params.get("messages", []))
            send_params = {**params, "messages": resolved_msgs}

        # Proactive trim: only for small-ctx models we've already seen run out of space
        _lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model
        # The same key is used for reading and for recording n_ctx below so a
        # learned limit is actually found on the next request.
        _nctx_key = (endpoint, _lookup_model)
        _known_nctx = _endpoint_nctx.get(_nctx_key)
        if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
            _pre_target = int(((_known_nctx - _known_nctx // 4)) / 1.2)
            _pre_msgs = send_params.get("messages", [])
            _pre_est = _count_message_tokens(_pre_msgs)
            if _pre_est > _pre_target:
                _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
                _dropped = len(_pre_msgs) - len(_pre_trimmed)
                print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
                send_params = {**send_params, "messages": _pre_trimmed}

        try:
            async_gen = await oclient.chat.completions.create(**send_params)
        except Exception as e:
            _e_str = str(e)
            _is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str
            print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True)
            if "does not support tools" in _e_str:
                # Model doesn't support tools — retry without them
                print("[ochat] retry: no tools", flush=True)
                params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
                async_gen = await oclient.chat.completions.create(**params_without_tools)
            elif _is_ctx_err:
                # Backend context limit hit — apply sliding-window trim (context-shift at message level)
                err_body = getattr(e, "body", {}) or {}
                err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
                if not isinstance(err_detail, dict):
                    err_detail = {}
                n_ctx_limit = err_detail.get("n_ctx", 0)
                actual_tokens = err_detail.get("n_prompt_tokens", 0)
                # Fallback: parse from string if body parsing yielded nothing (SDK may not parse llama-server errors)
                if not n_ctx_limit:
                    import re as _re
                    _m = _re.search(r"'n_ctx':\s*(\d+)", _e_str)
                    if _m:
                        n_ctx_limit = int(_m.group(1))
                    _m = _re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
                    if _m:
                        actual_tokens = int(_m.group(1))
                print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True)
                if not n_ctx_limit:
                    # Without a known limit there is nothing to trim towards.
                    raise
                if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
                    _endpoint_nctx[_nctx_key] = n_ctx_limit

                msgs_to_trim = send_params.get("messages", [])
                try:
                    cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
                    trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
                except Exception as _helper_exc:
                    print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True)
                    raise
                dropped = len(msgs_to_trim) - len(trimmed_messages)
                print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True)
                try:
                    async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages})
                    print("[ctx-trim] retry-1 ok", flush=True)
                except Exception as e2:
                    _e2_str = str(e2)
                    if "exceed_context_size_error" not in _e2_str and "exceeds the available context size" not in _e2_str:
                        raise
                    # Still too large — tool definitions likely consuming too many tokens, strip them too
                    print("[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True)
                    params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")}
                    async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages})
                    print("[ctx-trim] retry-2 ok", flush=True)
            elif _is_backend_connection_error(e):
                # Upstream connection failed (e.g. llama-server in router mode
                # whose delegated worker died). Mark (endpoint, model) so the
                # next request reroutes; the client will retry this one.
                print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
                await _mark_backend_unhealthy(endpoint, model, _e_str)
                raise
            elif "image input is not supported" in _e_str:
                # Model doesn't support images — strip and retry
                print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages")
                async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))})
            else:
                raise
    except BaseException:
        await decrement_usage(endpoint, tracking_model)
        raise

    # 4. Async generator — only streams the already-established async_gen
    async def stream_ochat_response():
        try:
            if stream == True:
                content_parts: list[str] = []
                usage_snapshot: dict = {}
                # Guard for the context-exhaustion heuristic below: when the
                # caller set a token budget, finish_reason=length may just
                # mean that budget was exhausted, not the context window.
                _req_max_tok = send_params.get("max_tokens") or send_params.get("max_completion_tokens")
                async for chunk in async_gen:
                    data = _dump_json(chunk)
                    if chunk.choices:
                        choice = chunk.choices[0]
                        delta = choice.delta
                        has_content = delta.content is not None
                        has_reasoning = (
                            getattr(delta, "reasoning_content", None) is not None
                            or getattr(delta, "reasoning", None) is not None
                        )
                        has_tool_calls = getattr(delta, "tool_calls", None) is not None
                        # The terminal chunk often has an empty delta and only
                        # finish_reason; forward it too (as the completions
                        # handler does) so clients see why generation ended.
                        has_finish = getattr(choice, "finish_reason", None) is not None
                        if has_content or has_reasoning or has_tool_calls or has_finish:
                            yield f"data: {data}\n\n".encode("utf-8")
                        if has_content and delta.content:
                            content_parts.append(delta.content)
                    elif chunk.usage is not None:
                        # Forward the usage-only final chunk (e.g. from llama-server)
                        yield f"data: {data}\n\n".encode("utf-8")

                    prompt_tok = 0
                    comp_tok = 0
                    if chunk.usage is not None:
                        prompt_tok = chunk.usage.prompt_tokens or 0
                        comp_tok = chunk.usage.completion_tokens or 0
                        usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok}
                    else:
                        llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
                        if llama_usage:
                            prompt_tok, comp_tok = llama_usage
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

                    # Detect context exhaustion mid-generation for small-ctx models.
                    if chunk.choices and chunk.choices[0].finish_reason == "length" and not _req_max_tok:
                        _inferred_nctx = (prompt_tok + comp_tok) or 0
                        if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
                            _endpoint_nctx[_nctx_key] = _inferred_nctx
                            print(f"[ctx-cache] finish_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)

                # Cache assembled streaming response — before [DONE] so it always runs
                if _cache is not None and _cache_enabled and content_parts:
                    assembled = orjson.dumps({
                        "model": model,
                        "choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(content_parts)}, "finish_reason": "stop"}],
                        **({"usage": usage_snapshot} if usage_snapshot else {}),
                    }) + b"\n"
                    try:
                        await _cache.set_chat("openai_chat", model, messages, assembled)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_chat streaming) failed: {_ce}")
                yield b"data: [DONE]\n\n"
            else:
                prompt_tok = 0
                comp_tok = 0
                if async_gen.usage is not None:
                    prompt_tok = async_gen.usage.prompt_tokens or 0
                    comp_tok = async_gen.usage.completion_tokens or 0
                else:
                    llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
                    if llama_usage:
                        prompt_tok, comp_tok = llama_usage
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                cache_bytes = _dump_json(async_gen).encode("utf-8") + b"\n"
                yield cache_bytes
                # Cache non-streaming response
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_chat("openai_chat", model, messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_chat non-streaming) failed: {_ce}")

        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 5. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_ochat_response(),
        media_type="text/event-stream" if stream else "application/json",
    )
|
||
|
||
|
||
@router.post("/v1/completions")
async def openai_completions_proxy(request: Request):
    """
    Proxy an OpenAI API compatible (legacy) completions request to the chosen backend.

    The JSON body must be an object carrying ``model`` and ``prompt``. The
    response is served from the LLM cache when the request opted in via
    ``{"nomyo": {"cache": true}}`` and a matching entry exists; otherwise the
    upstream result is relayed (SSE when ``stream`` is set, a single JSON
    document otherwise), token usage is recorded and the result is cached.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, or missing
            ``model`` / ``prompt``.
        Exception: upstream errors are re-raised after the endpoint usage
            counter has been released; connection failures additionally mark
            the backend unhealthy.
    """
    config = get_config()

    def _dump_json(obj) -> str:
        """Serialise an SDK object (or plain data) to a JSON string."""
        if hasattr(obj, "model_dump_json"):
            return obj.model_dump_json()
        # orjson.dumps returns bytes; decode so callers can always treat the
        # result as str.
        return orjson.dumps(obj).decode("utf-8")

    # 1. Parse and validate request
    try:
        # orjson decodes the raw bytes itself and reports invalid UTF-8 as a
        # JSONDecodeError, so a malformed body becomes a 400 instead of a 500.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    prompt = payload.get("prompt")
    stream = payload.get("stream")
    _nomyo = payload.get("nomyo")
    _cache_enabled = _nomyo.get("cache", False) if isinstance(_nomyo, dict) else False

    if not model or not isinstance(model, str):
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not prompt:
        raise HTTPException(
            status_code=400, detail="Missing required field 'prompt'"
        )

    if ":latest" in model:
        model = model.split(":latest")[0]

    params = {
        "prompt": prompt,
        "model": model,
    }

    # Optional parameters are forwarded only when the client supplied them.
    for _key in (
        "frequency_penalty",
        "presence_penalty",
        "seed",
        "stop",
        "stream",
        "temperature",
        "top_p",
        "suffix",
    ):
        _value = payload.get(_key)
        if _value is not None:
            params[_key] = _value

    # The legacy completions API only knows ``max_tokens``; a client-supplied
    # ``max_completion_tokens`` is folded into it rather than being passed on
    # as an unknown keyword to ``completions.create``.
    max_tokens = payload.get("max_tokens")
    if max_tokens is None:
        max_tokens = payload.get("max_completion_tokens")
    if max_tokens is not None:
        params["max_tokens"] = max_tokens

    # stream_options is only meaningful for streaming requests: default to
    # include_usage when streaming, otherwise forward it only if the client
    # explicitly sent a non-empty value.
    stream_options = payload.get("stream_options")
    if stream:
        params["stream_options"] = stream_options or {"include_usage": True}
    elif stream_options:
        params["stream_options"] = stream_options

    # Cache lookup — completions prompt mapped to a single-turn messages list
    _cache = get_llm_cache()
    _compl_messages = [{"role": "user", "content": prompt}]
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_chat("openai_completions", model, _compl_messages)
        if _cached is not None:
            if stream:
                _sse = openai_nonstream_to_sse(_cached, model)

                async def _serve_cached_ocompl_stream():
                    yield _sse

                return StreamingResponse(_serve_cached_ocompl_stream(), media_type="text/event-stream")

            async def _serve_cached_ocompl_json():
                yield _cached

            return StreamingResponse(_serve_cached_ocompl_json(), media_type="application/json")

    # 2. Endpoint logic
    _affinity_key = _conversation_fingerprint(model, None, prompt)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)

    # 3. Make the API call in handler scope (try/except inside async generators
    # is unreliable). Any failure between choose_endpoint() and the generator
    # taking over must release the usage counter exactly once.
    try:
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
        try:
            async_gen = await oclient.completions.create(**params)
        except Exception as e:
            if _is_backend_connection_error(e):
                print(f"[ocompl] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
                await _mark_backend_unhealthy(endpoint, model, str(e))
            raise
    except BaseException:
        await decrement_usage(endpoint, tracking_model)
        raise

    async def stream_ocompletions_response(model=model):
        try:
            if stream == True:
                text_parts: list[str] = []
                usage_snapshot: dict = {}
                async for chunk in async_gen:
                    data = _dump_json(chunk)
                    if chunk.choices:
                        choice = chunk.choices[0]
                        has_text = getattr(choice, "text", None) is not None
                        has_reasoning = (
                            getattr(choice, "reasoning_content", None) is not None
                            or getattr(choice, "reasoning", None) is not None
                        )
                        if has_text or has_reasoning or choice.finish_reason is not None:
                            yield f"data: {data}\n\n".encode("utf-8")
                        if has_text and choice.text:
                            text_parts.append(choice.text)
                    elif chunk.usage is not None:
                        # Forward the usage-only final chunk (e.g. from llama-server)
                        yield f"data: {data}\n\n".encode("utf-8")

                    prompt_tok = 0
                    comp_tok = 0
                    if chunk.usage is not None:
                        prompt_tok = chunk.usage.prompt_tokens or 0
                        comp_tok = chunk.usage.completion_tokens or 0
                        usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok}
                    else:
                        llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
                        if llama_usage:
                            prompt_tok, comp_tok = llama_usage
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

                # Cache assembled streaming response — before [DONE] so it always runs
                if _cache is not None and _cache_enabled and text_parts:
                    assembled = orjson.dumps({
                        "model": model,
                        "choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(text_parts)}, "finish_reason": "stop"}],
                        **({"usage": usage_snapshot} if usage_snapshot else {}),
                    }) + b"\n"
                    try:
                        await _cache.set_chat("openai_completions", model, _compl_messages, assembled)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_completions streaming) failed: {_ce}")
                # Final DONE event
                yield b"data: [DONE]\n\n"
            else:
                prompt_tok = 0
                comp_tok = 0
                if async_gen.usage is not None:
                    prompt_tok = async_gen.usage.prompt_tokens or 0
                    comp_tok = async_gen.usage.completion_tokens or 0
                else:
                    llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
                    if llama_usage:
                        prompt_tok, comp_tok = llama_usage
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                cache_bytes = _dump_json(async_gen).encode("utf-8") + b"\n"
                yield cache_bytes
                # Cache non-streaming response
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_chat("openai_completions", model, _compl_messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_completions non-streaming) failed: {_ce}")

        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_ocompletions_response(),
        media_type="text/event-stream" if stream else "application/json",
    )
|
||
|
||
|
||
@router.get("/v1/models")
async def openai_models_proxy(request: Request):
    """
    Proxy an OpenAI API models request to all configured backends and reply
    with a list of models that is unique by ``name``.

    Sources, queried concurrently:
      * configured endpoints without ``/v1`` in their URL: ``/api/tags``
        (``models`` key)
      * external OpenAI-compatible endpoints: ``/models`` (``data`` key)
      * llama-server endpoints: ``/models`` (``data`` key); every returned
        model is listed, no load-status filtering is applied here

    Each entry is given both ``id`` and ``name``: an existing ``id`` is
    copied to ``name``, otherwise ``name`` is copied to ``id``.
    """
    config = get_config()

    # 1. Endpoints queried via /api/tags
    ollama_tasks = [
        fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8)
        for ep in config.endpoints
        if "/v1" not in ep
    ]
    # 2. External OpenAI endpoints (Groq, OpenAI, etc.) via /models
    ext_openai_tasks = [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
        for ep in config.endpoints
        if is_ext_openai_endpoint(ep)
    ]
    # 3. llama-server endpoints via /models. These may not be part of
    # config.endpoints; dict.fromkeys de-duplicates while keeping the
    # configured order so the result order is deterministic.
    llama_tasks = [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
        for ep in dict.fromkeys(config.llama_server_endpoints)
    ]

    # Run every probe concurrently instead of one backend family at a time;
    # gather() preserves argument order, so Ollama results still come first.
    all_tasks = [*ollama_tasks, *ext_openai_tasks, *llama_tasks]
    model_lists = await asyncio.gather(*all_tasks) if all_tasks else []

    collected = []
    for modellist in model_lists:
        for model in modellist or []:
            # Work on a copy so objects owned by the fetch layer are not mutated.
            entry = dict(model)
            if "id" in entry:
                entry["name"] = entry["id"]
            else:
                # Relabel models that only carry a name with an OpenAI-style id.
                entry["id"] = entry.get("name", "")
            collected.append(entry)

    # 4. Return a JSONResponse with a deduplicated list of unique models for inference
    return JSONResponse(
        content={"data": dedupe_on_keys(collected, ['name'])},
        status_code=200,
    )
|
||
|
||
|
||
@router.post("/v1/rerank")
@router.post("/rerank")
async def rerank_proxy(request: Request):
    """
    Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint.

    Compatible with the Jina/Cohere rerank API convention used by llama-server,
    vLLM, and services such as Cohere and Jina AI.

    Ollama does not natively support reranking; requests routed to a plain Ollama
    endpoint will receive a 501 Not Implemented response.

    Request body:
        model (str, required) – reranker model name
        query (str, required) – search query
        documents (list[str], required) – candidate documents to rank
        top_n (int, optional) – limit returned results (default: all)
        return_documents (bool, optional) – include document text in results
        max_tokens_per_doc (int, optional) – truncation limit per document

    Response (Jina/Cohere-compatible):
        {
            "id": "...",
            "model": "...",
            "usage": {"prompt_tokens": N, "total_tokens": N},
            "results": [{"index": 0, "relevance_score": 0.95}, ...]
        }

    Raises:
        HTTPException: 400 for an invalid body, 404 when no endpoint serves
            the model, 501 for plain Ollama endpoints, 502 when the upstream
            answers with a non-JSON body, or the upstream status code when it
            is >= 400.
    """
    config = get_config()

    try:
        # orjson decodes the raw bytes itself and reports invalid UTF-8 as a
        # JSONDecodeError, so a malformed body becomes a 400 instead of a 500.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    query = payload.get("query")
    documents = payload.get("documents")

    if not model or not isinstance(model, str):
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not query:
        raise HTTPException(status_code=400, detail="Missing required field 'query'")
    if not isinstance(documents, list) or not documents:
        raise HTTPException(status_code=400, detail="Missing or empty required field 'documents' (must be a non-empty list)")

    # Determine which endpoint serves this model
    try:
        endpoint, tracking_model = await choose_endpoint(model)
    except RuntimeError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    # Everything below runs under one finally so the usage counter taken by
    # choose_endpoint() is released exactly once on every exit path.
    try:
        # Ollama endpoints have no native rerank support
        if not is_openai_compatible(endpoint):
            raise HTTPException(
                status_code=501,
                detail=(
                    f"Endpoint '{endpoint}' is a plain Ollama instance which does not support "
                    "reranking. Use a llama-server or OpenAI-compatible endpoint with a "
                    "dedicated reranker model."
                ),
            )

        if ":latest" in model:
            model = model.split(":latest")[0]

        # Build upstream rerank request body – forward only recognised fields
        upstream_payload: dict = {"model": model, "query": query, "documents": documents}
        for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
            if optional_key in payload:
                upstream_payload[optional_key] = payload[optional_key]

        # Determine upstream URL (trailing slashes are stripped so the join
        # never produces "//rerank"):
        #   llama-server: endpoint may or may not already contain /v1
        #   external OpenAI-compatible: ep2base gives us the /v1 base
        if endpoint in config.llama_server_endpoints:
            base = endpoint.rstrip("/")
            rerank_url = f"{base}/rerank" if "/v1" in endpoint else f"{base}/v1/rerank"
        else:
            rerank_url = f"{ep2base(endpoint).rstrip('/')}/rerank"

        api_key = config.api_keys.get(endpoint, "no-key")
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }

        client: aiohttp.ClientSession = get_session(endpoint)
        async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp:
            response_bytes = await resp.read()
            if resp.status >= 400:
                raise HTTPException(
                    status_code=resp.status,
                    detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")),
                )
        try:
            data = orjson.loads(response_bytes)
        except orjson.JSONDecodeError as e:
            raise HTTPException(
                status_code=502,
                detail="Upstream rerank endpoint returned a non-JSON response",
            ) from e

        # Record token usage if the upstream returned a usage object
        usage = data.get("usage") if isinstance(data, dict) else None
        if not isinstance(usage, dict):
            usage = {}
        prompt_tok = usage.get("prompt_tokens") or 0
        total_tok = usage.get("total_tokens") or 0
        # For reranking there are no completion tokens; record prompt tokens,
        # falling back to the total when only that is reported.
        if prompt_tok or total_tok:
            await token_queue.put((endpoint, tracking_model, prompt_tok or total_tok, 0))

        return JSONResponse(content=data)
    finally:
        await decrement_usage(endpoint, tracking_model)
|