805 lines
37 KiB
Python
805 lines
37 KiB
Python
|
|
"""OpenAI-compatible routes (``/v1/embeddings``, ``/v1/chat/completions``,
``/v1/completions``, ``/v1/models``, ``/v1/rerank`` and ``/rerank``).

The chat-completions handler carries the full reactive-trim logic for
``exceed_context_size_error``; both the chat-completions and completions
handlers mark a backend unhealthy on connection failures
(``_mark_backend_unhealthy``) so later requests are rerouted. The streaming
branches assemble cached responses on the fly so caching works for both
streaming and non-streaming clients.
"""
|
|||
|
|
import asyncio
|
|||
|
|
import base64
|
|||
|
|
import math
|
|||
|
|
|
|||
|
|
import aiohttp
|
|||
|
|
import orjson
|
|||
|
|
from fastapi import APIRouter, HTTPException, Request
|
|||
|
|
from starlette.responses import JSONResponse, StreamingResponse
|
|||
|
|
|
|||
|
|
from cache import get_llm_cache, openai_nonstream_to_sse
|
|||
|
|
from config import get_config
|
|||
|
|
from context_window import (
|
|||
|
|
_count_message_tokens,
|
|||
|
|
_trim_messages_for_context,
|
|||
|
|
_calibrated_trim_target,
|
|||
|
|
_endpoint_nctx,
|
|||
|
|
_CTX_TRIM_SMALL_LIMIT,
|
|||
|
|
)
|
|||
|
|
from fingerprint import _conversation_fingerprint
|
|||
|
|
from security import _mask_secrets
|
|||
|
|
from state import token_queue, app_state, default_headers
|
|||
|
|
from backends.health import _is_backend_connection_error, _mark_backend_unhealthy
|
|||
|
|
from backends.normalize import (
|
|||
|
|
dedupe_on_keys,
|
|||
|
|
ep2base,
|
|||
|
|
is_ext_openai_endpoint,
|
|||
|
|
is_openai_compatible,
|
|||
|
|
_normalize_llama_model_name,
|
|||
|
|
)
|
|||
|
|
from backends.probe import fetch
|
|||
|
|
from backends.sessions import _make_openai_client, get_session
|
|||
|
|
from requests.messages import _strip_assistant_prefill, _strip_images_from_messages
|
|||
|
|
from requests.rechunk import rechunk
|
|||
|
|
from routing import choose_endpoint, decrement_usage
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Router collecting the OpenAI-compatible routes declared in this module.
# NOTE(review): presumably included into the FastAPI app elsewhere — not visible here.
router = APIRouter()
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/v1/embeddings")
async def openai_embedding_proxy(request: Request):
    """
    Proxy an OpenAI-compatible embedding request and reply with the embeddings.

    Multimodal ``input`` items are reduced to their text parts (image and other
    non-text parts are dropped). Non-finite floats (NaN/±inf) in the returned
    vectors are replaced with ``0.0`` so the response is valid strict JSON.

    Raises:
        HTTPException(400): the body is not valid JSON / not a JSON object, or
            ``model`` / ``input`` is missing or empty after normalization.
    """
    config = get_config()

    # 1. Parse and validate request
    try:
        # orjson accepts bytes directly and reports invalid UTF-8 as a
        # JSONDecodeError, so malformed bodies map to 400 instead of a 500.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    doc = payload.get("input")

    # Normalize multimodal input: keep only text parts for embedding models
    if isinstance(doc, list):
        normalized = []
        for item in doc:
            if isinstance(item, dict):
                # Multimodal content part - keep text, skip image_url and other types
                if item.get("type") == "text":
                    normalized.append(item.get("text", ""))
            else:
                normalized.append(item)
        # A single surviving element is unwrapped to a scalar input.
        doc = normalized if len(normalized) != 1 else normalized[0]
    elif isinstance(doc, dict) and doc.get("type") == "text":
        doc = doc.get("text", "")

    if not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not doc:
        raise HTTPException(status_code=400, detail="Missing required field 'input'")

    # 2. Endpoint logic. choose_endpoint is paired with decrement_usage; everything
    # after it runs inside try/finally so the counter is released on any failure,
    # including client construction.
    endpoint, tracking_model = await choose_endpoint(model)
    try:
        if is_openai_compatible(endpoint):
            api_key = config.api_keys.get(endpoint, "no-key")
        else:
            api_key = "ollama"
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=api_key)

        response = await oclient.embeddings.create(input=doc, model=model)
        result = response.model_dump()
        # "data" may be present but null; only list-valued embeddings are sanitized
        # (a string such as a base64 payload must not be iterated char by char).
        for item in result.get("data") or []:
            emb = item.get("embedding")
            if isinstance(emb, list) and emb:
                item["embedding"] = [
                    0.0 if isinstance(v, float) and not math.isfinite(v) else v for v in emb
                ]
        return JSONResponse(content=result)
    finally:
        await decrement_usage(endpoint, tracking_model)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/v1/chat/completions")
async def openai_chat_completions_proxy(request: Request):
    """
    Proxy an OpenAI-compatible chat-completions request to a backend.

    Steps:
      1. Parse and validate the JSON body; build the upstream parameters.
      2. Reject SVG image inputs and, when ``nomyo.cache`` is enabled, answer
         from the LLM cache if an entry exists (before any endpoint is chosen).
      3. Choose an endpoint (conversation fingerprint as affinity key), inline
         remote image URLs for non-external backends, and proactively trim the
         history for models whose small context size was recorded earlier.
      4. Call the backend, applying one recovery strategy on known failures:
         retry without tools, sliding-window trim on context overflow, mark the
         backend unhealthy on connection errors, or retry text-only when image
         input is unsupported.
      5. Stream the result back (SSE when ``stream`` is truthy, otherwise a
         single JSON document), report token usage and populate the cache.

    The usage counter paired with ``choose_endpoint`` is released by the
    response generator's ``finally`` on success, or before any exception from
    the setup/call phase propagates.

    Raises:
        HTTPException(400): invalid JSON, non-object body, missing/invalid
            ``model``, ``messages`` not a list, or an SVG image input.
        Any backend exception the recovery strategies could not resolve.
    """
    config = get_config()

    # 1. Parse and validate request
    try:
        # orjson accepts bytes and reports invalid UTF-8 as a JSONDecodeError.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    messages = payload.get("messages")
    stream = payload.get("stream")
    # "nomyo" may be absent, null or not an object; only a dict can enable caching.
    _nomyo = payload.get("nomyo")
    _cache_enabled = _nomyo.get("cache", False) if isinstance(_nomyo, dict) else False

    if not model or not isinstance(model, str):
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not isinstance(messages, list):
        raise HTTPException(
            status_code=400, detail="Missing required field 'messages' (must be a list)"
        )

    # Drop a ":latest" tag and anything after it.
    model = model.split(":latest", 1)[0]
    messages = _strip_assistant_prefill(messages)

    params = {"messages": messages, "model": model}
    optional_params = {
        "tools": payload.get("tools"),
        "response_format": payload.get("response_format"),
        # include_usage makes the backend report usage in the final stream chunk.
        # Only default it when streaming: the OpenAI API documents stream_options
        # as valid only together with stream=true.
        "stream_options": payload.get("stream_options") or ({"include_usage": True} if stream else None),
        "max_completion_tokens": payload.get("max_completion_tokens"),
        "max_tokens": payload.get("max_tokens"),
        "temperature": payload.get("temperature"),
        "top_p": payload.get("top_p"),
        "seed": payload.get("seed"),
        "presence_penalty": payload.get("presence_penalty"),
        "frequency_penalty": payload.get("frequency_penalty"),
        "stop": payload.get("stop"),
        "stream": stream,
        "logprobs": payload.get("logprobs"),
        "top_logprobs": payload.get("top_logprobs"),
    }
    params.update({k: v for k, v in optional_params.items() if v is not None})

    # Reject unsupported image formats (SVG) before doing any work
    for _msg in messages:
        _content = _msg.get("content") if isinstance(_msg, dict) else None
        if not isinstance(_content, list):
            continue
        for _item in _content:
            if not isinstance(_item, dict) or _item.get("type") != "image_url":
                continue
            _image = _item.get("image_url")
            _url = _image.get("url") if isinstance(_image, dict) else None
            if not isinstance(_url, str):
                continue
            # MIME types are case-insensitive, so compare the lower-cased URL.
            _url_lc = _url.lower()
            if _url_lc.startswith("data:image/svg") or _url_lc.endswith(".svg"):
                raise HTTPException(
                    status_code=400,
                    detail="SVG images are not supported. Please convert the image to PNG or JPEG before sending.",
                )

    # Cache lookup — before endpoint selection, so a hit never takes a backend slot
    _cache = get_llm_cache()
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_chat("openai_chat", model, messages)
        if _cached is not None:
            if stream:
                _sse = openai_nonstream_to_sse(_cached, model)

                async def _serve_cached_ochat_stream():
                    yield _sse

                return StreamingResponse(_serve_cached_ochat_stream(), media_type="text/event-stream")

            async def _serve_cached_ochat_json():
                yield _cached

            return StreamingResponse(_serve_cached_ochat_json(), media_type="application/json")

    # 2. Endpoint logic
    _affinity_key = _conversation_fingerprint(model, messages, None)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)

    # 3. Helpers (closures over endpoint/model/oclient/_lookup_model)
    def _is_ctx_error(err_str: str) -> bool:
        """Return True when *err_str* looks like a backend context-window overflow."""
        return "exceed_context_size_error" in err_str or "exceeds the available context size" in err_str

    def _remember_nctx(n_ctx: int) -> None:
        """Record a small context size for this endpoint/model for proactive trimming.

        Stored under ``_lookup_model`` (the key the proactive-trim check reads).
        NOTE(review): also stored under the raw model name, the key used before,
        in case readers outside this handler still use it.
        """
        _endpoint_nctx[(endpoint, _lookup_model)] = n_ctx
        if _lookup_model != model:
            _endpoint_nctx[(endpoint, model)] = n_ctx

    def _context_limits_from_error(err: Exception, err_str: str) -> tuple:
        """Extract ``(n_ctx, n_prompt_tokens)`` from a context-overflow error (0 = unknown)."""
        err_body = getattr(err, "body", None)
        err_detail = {}
        if isinstance(err_body, dict):
            # The openai SDK already unwraps the top-level "error" member into
            # ``.body``; accept both the wrapped and the unwrapped shape.
            inner = err_body.get("error", err_body)
            if isinstance(inner, dict):
                err_detail = inner
        n_ctx_limit = err_detail.get("n_ctx") or 0
        actual_tokens = err_detail.get("n_prompt_tokens") or 0
        # Fallback: parse the error text (repr of the body dict, or raw JSON text).
        if not n_ctx_limit or not actual_tokens:
            import re

            if not n_ctx_limit:
                m = re.search(r"""['"]n_ctx['"]:\s*(\d+)""", err_str)
                if m:
                    n_ctx_limit = int(m.group(1))
            if not actual_tokens:
                m = re.search(r"""['"]n_prompt_tokens['"]:\s*(\d+)""", err_str)
                if m:
                    actual_tokens = int(m.group(1))
        return n_ctx_limit, actual_tokens

    async def _normalize_images_in_messages(msgs: list) -> list:
        """Fetch remote image URLs and convert them to base64 data URLs so
        Ollama/llama-server can handle them without making outbound HTTP requests.

        Items that cannot be fetched (network error or HTTP error status) are
        passed through unchanged.

        NOTE(review): the URL comes straight from the client, so this issues
        server-side requests to arbitrary hosts (SSRF risk) with no size limit.
        Consider an allow-list and a maximum download size.
        """
        resolved = []
        for msg in msgs:
            content = msg.get("content") if isinstance(msg, dict) else None
            if not isinstance(content, list):
                resolved.append(msg)
                continue
            new_content = []
            for item in content:
                image = (
                    item.get("image_url")
                    if isinstance(item, dict) and item.get("type") == "image_url"
                    else None
                )
                url = image.get("url") if isinstance(image, dict) else None
                if not isinstance(url, str) or not url or url.startswith("data:"):
                    new_content.append(item)
                    continue
                try:
                    http: aiohttp.ClientSession = app_state["session"]
                    async with http.get(url) as resp:
                        # Never inline an HTTP error page as if it were an image.
                        resp.raise_for_status()
                        ctype = resp.headers.get("Content-Type", "image/jpeg").split(";")[0].strip()
                        img_bytes = await resp.read()
                    b64 = base64.b64encode(img_bytes).decode("utf-8")
                    new_content.append({
                        "type": "image_url",
                        "image_url": {"url": f"data:{ctype};base64,{b64}"},
                    })
                except Exception as _ie:
                    print(f"[image] Failed to fetch image URL: {_ie}")
                    new_content.append(item)
            resolved.append({**msg, "content": new_content})
        return resolved

    async def _retry_with_trimmed_context(call_params: dict, n_ctx_limit: int, actual_tokens: int):
        """Sliding-window trim the history to fit *n_ctx_limit* and retry.

        If the trimmed request still overflows, retry once more without tool
        definitions, which also consume prompt tokens. Errors propagate.
        """
        msgs_to_trim = call_params.get("messages", [])
        try:
            cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
            trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
        except Exception as _helper_exc:
            print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True)
            raise
        dropped = len(msgs_to_trim) - len(trimmed_messages)
        print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True)
        try:
            result = await oclient.chat.completions.create(**{**call_params, "messages": trimmed_messages})
        except Exception as e2:
            if not _is_ctx_error(str(e2)):
                raise
            # Still too large — tool definitions likely consuming too many tokens, strip them too
            print("[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True)
            params_no_tools = {k: v for k, v in call_params.items() if k not in ("tools", "tool_choice")}
            result = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages})
            print("[ctx-trim] retry-2 ok", flush=True)
            return result
        print("[ctx-trim] retry-1 ok", flush=True)
        return result

    async def _create_with_recovery(call_params: dict):
        """Issue the chat-completions call, applying one recovery on known errors.

        Returns the SDK result (an async stream when streaming). Unrecoverable
        errors propagate; the caller releases the usage counter.
        """
        try:
            return await oclient.chat.completions.create(**call_params)
        except Exception as e:
            _e_str = str(e)
            _is_ctx_err = _is_ctx_error(_e_str)
            print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True)
            if "does not support tools" in _e_str:
                # Model doesn't support tools — retry without them
                print("[ochat] retry: no tools", flush=True)
                params_without_tools = {k: v for k, v in call_params.items() if k not in ("tools", "tool_choice")}
                return await oclient.chat.completions.create(**params_without_tools)
            if _is_ctx_err:
                # Backend context limit hit — sliding-window trim (context shift at message level)
                n_ctx_limit, actual_tokens = _context_limits_from_error(e, _e_str)
                print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True)
                if not n_ctx_limit:
                    raise
                if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
                    _remember_nctx(n_ctx_limit)
                return await _retry_with_trimmed_context(call_params, n_ctx_limit, actual_tokens)
            if _is_backend_connection_error(e):
                # Upstream connection failed (e.g. llama-server in router mode whose
                # delegated worker died). Mark (endpoint, model) so the next request
                # reroutes; the client will retry this one.
                print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
                await _mark_backend_unhealthy(endpoint, model, _e_str)
                raise
            if "image input is not supported" in _e_str:
                # Model doesn't support images — strip and retry
                print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages")
                return await oclient.chat.completions.create(
                    **{**call_params, "messages": _strip_images_from_messages(call_params.get("messages", []))}
                )
            raise

    # Prepare and issue the call in handler scope — try/except inside async
    # generators is unreliable with Starlette's streaming machinery, so errors are
    # resolved here before the generator starts. Any failure in this phase
    # releases the usage counter exactly once, then propagates.
    try:
        # Key used for _endpoint_nctx (llama-server model names are normalized).
        _lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))

        send_params = params
        if not is_ext_openai_endpoint(endpoint):
            resolved_msgs = await _normalize_images_in_messages(params.get("messages", []))
            send_params = {**params, "messages": resolved_msgs}

        # Proactive trim: only for small-ctx models we've already seen run out of space
        _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model))
        if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
            # 3/4 of the window, divided by 1.2 — presumably headroom for error in
            # the _count_message_tokens estimate.
            _pre_target = int((_known_nctx - _known_nctx // 4) / 1.2)
            _pre_msgs = send_params.get("messages", [])
            _pre_est = _count_message_tokens(_pre_msgs)
            if _pre_est > _pre_target:
                _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
                _dropped = len(_pre_msgs) - len(_pre_trimmed)
                print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
                send_params = {**send_params, "messages": _pre_trimmed}

        async_gen = await _create_with_recovery(send_params)
    except BaseException:
        await decrement_usage(endpoint, tracking_model)
        raise

    # Loop-invariant: an explicit token budget means finish_reason == "length" may
    # reflect the caller's limit rather than the context window.
    _req_max_tok = send_params.get("max_tokens") or send_params.get("max_completion_tokens")

    # 4. Async generator — only streams the already-established async_gen
    async def stream_ochat_response():
        try:
            if stream:
                content_parts: list[str] = []
                usage_snapshot: dict = {}
                async for chunk in async_gen:
                    data = (
                        chunk.model_dump_json()
                        if hasattr(chunk, "model_dump_json")
                        else orjson.dumps(chunk).decode("utf-8")
                    )
                    choice = chunk.choices[0] if chunk.choices else None
                    if choice is not None:
                        delta = choice.delta
                        delta_text = getattr(delta, "content", None)
                        has_reasoning = (
                            getattr(delta, "reasoning_content", None) is not None
                            or getattr(delta, "reasoning", None) is not None
                        )
                        has_tool_calls = getattr(delta, "tool_calls", None) is not None
                        # Also forward the terminal chunk (typically an empty delta
                        # carrying only finish_reason) so clients see e.g.
                        # "tool_calls" or "length".
                        if (
                            delta_text is not None
                            or has_reasoning
                            or has_tool_calls
                            or getattr(choice, "finish_reason", None) is not None
                        ):
                            yield f"data: {data}\n\n".encode("utf-8")
                        if delta_text:
                            content_parts.append(delta_text)
                    elif chunk.usage is not None:
                        # Forward the usage-only final chunk (e.g. from llama-server)
                        yield f"data: {data}\n\n".encode("utf-8")

                    prompt_tok = 0
                    comp_tok = 0
                    if chunk.usage is not None:
                        prompt_tok = chunk.usage.prompt_tokens or 0
                        comp_tok = chunk.usage.completion_tokens or 0
                        usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok}
                    else:
                        llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
                        if llama_usage:
                            prompt_tok, comp_tok = llama_usage
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

                    # Detect context exhaustion mid-generation for small-ctx models.
                    if (
                        choice is not None
                        and getattr(choice, "finish_reason", None) == "length"
                        and not _req_max_tok
                    ):
                        _inferred_nctx = prompt_tok + comp_tok
                        if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
                            _remember_nctx(_inferred_nctx)
                            print(f"[ctx-cache] finish_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)

                # Cache assembled streaming response — before [DONE] so it always runs
                if _cache is not None and _cache_enabled and content_parts:
                    assembled = orjson.dumps({
                        "model": model,
                        "choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(content_parts)}, "finish_reason": "stop"}],
                        **({"usage": usage_snapshot} if usage_snapshot else {}),
                    }) + b"\n"
                    try:
                        await _cache.set_chat("openai_chat", model, messages, assembled)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_chat streaming) failed: {_ce}")
                yield b"data: [DONE]\n\n"
            else:
                prompt_tok = 0
                comp_tok = 0
                if async_gen.usage is not None:
                    prompt_tok = async_gen.usage.prompt_tokens or 0
                    comp_tok = async_gen.usage.completion_tokens or 0
                else:
                    llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
                    if llama_usage:
                        prompt_tok, comp_tok = llama_usage
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                json_line = (
                    async_gen.model_dump_json()
                    if hasattr(async_gen, "model_dump_json")
                    else orjson.dumps(async_gen).decode("utf-8")
                )
                cache_bytes = json_line.encode("utf-8") + b"\n"
                yield cache_bytes
                # Cache non-streaming response
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_chat("openai_chat", model, messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_chat non-streaming) failed: {_ce}")
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 5. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_ochat_response(),
        media_type="text/event-stream" if stream else "application/json",
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/v1/completions")
async def openai_completions_proxy(request: Request):
    """
    Proxy an OpenAI-compatible (legacy) text-completions request to a backend.

    The prompt is cached under a synthetic single-turn messages list
    (``[{"role": "user", "content": prompt}]``). On a backend connection error
    the (endpoint, model) pair is marked unhealthy before the error propagates.
    The response is SSE when ``stream`` is truthy, otherwise one JSON document.

    Raises:
        HTTPException(400): invalid JSON, non-object body, missing/invalid
            ``model``, or missing ``prompt``.
        Any exception from the completions call (the usage counter paired with
        ``choose_endpoint`` is released before it propagates).
    """
    config = get_config()

    # 1. Parse and validate request
    try:
        # orjson accepts bytes and reports invalid UTF-8 as a JSONDecodeError.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    prompt = payload.get("prompt")
    stream = payload.get("stream")
    # "nomyo" may be absent, null or not an object; only a dict can enable caching.
    _nomyo = payload.get("nomyo")
    _cache_enabled = _nomyo.get("cache", False) if isinstance(_nomyo, dict) else False

    if not model or not isinstance(model, str):
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not prompt:
        raise HTTPException(status_code=400, detail="Missing required field 'prompt'")

    # Drop a ":latest" tag and anything after it.
    model = model.split(":latest", 1)[0]

    params = {"prompt": prompt, "model": model}
    optional_params = {
        "frequency_penalty": payload.get("frequency_penalty"),
        "presence_penalty": payload.get("presence_penalty"),
        "seed": payload.get("seed"),
        "stop": payload.get("stop"),
        "stream": stream,
        # Default include_usage only when streaming: stream_options is documented
        # as valid only together with stream=true.
        "stream_options": payload.get("stream_options") or ({"include_usage": True} if stream else None),
        "temperature": payload.get("temperature"),
        "top_p": payload.get("top_p"),
        "max_tokens": payload.get("max_tokens"),
        "max_completion_tokens": payload.get("max_completion_tokens"),
        "suffix": payload.get("suffix"),
    }
    params.update({k: v for k, v in optional_params.items() if v is not None})

    # Cache lookup — completions prompt mapped to a single-turn messages list
    _cache = get_llm_cache()
    _compl_messages = [{"role": "user", "content": prompt}]
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_chat("openai_completions", model, _compl_messages)
        if _cached is not None:
            if stream:
                _sse = openai_nonstream_to_sse(_cached, model)

                async def _serve_cached_ocompl_stream():
                    yield _sse

                return StreamingResponse(_serve_cached_ocompl_stream(), media_type="text/event-stream")

            async def _serve_cached_ocompl_json():
                yield _cached

            return StreamingResponse(_serve_cached_ocompl_json(), media_type="application/json")

    # 2. Endpoint logic
    _affinity_key = _conversation_fingerprint(model, None, prompt)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)

    # 3. Make the API call in handler scope (try/except inside async generators is
    # unreliable). Any failure here releases the usage counter once, then propagates.
    try:
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
        async_gen = await oclient.completions.create(**params)
    except BaseException as e:
        try:
            if isinstance(e, Exception) and _is_backend_connection_error(e):
                # Mark (endpoint, model) so the next request reroutes.
                print(f"[ocompl] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
                await _mark_backend_unhealthy(endpoint, model, str(e))
        finally:
            await decrement_usage(endpoint, tracking_model)
        raise

    # 4. Async generator — only streams the already-established async_gen
    async def stream_ocompletions_response():
        try:
            if stream:
                text_parts: list[str] = []
                usage_snapshot: dict = {}
                async for chunk in async_gen:
                    data = (
                        chunk.model_dump_json()
                        if hasattr(chunk, "model_dump_json")
                        else orjson.dumps(chunk).decode("utf-8")
                    )
                    if chunk.choices:
                        choice = chunk.choices[0]
                        has_text = getattr(choice, "text", None) is not None
                        has_reasoning = (
                            getattr(choice, "reasoning_content", None) is not None
                            or getattr(choice, "reasoning", None) is not None
                        )
                        if has_text or has_reasoning or choice.finish_reason is not None:
                            yield f"data: {data}\n\n".encode("utf-8")
                        if has_text and choice.text:
                            text_parts.append(choice.text)
                    elif chunk.usage is not None:
                        # Forward the usage-only final chunk (e.g. from llama-server)
                        yield f"data: {data}\n\n".encode("utf-8")

                    prompt_tok = 0
                    comp_tok = 0
                    if chunk.usage is not None:
                        prompt_tok = chunk.usage.prompt_tokens or 0
                        comp_tok = chunk.usage.completion_tokens or 0
                        usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok}
                    else:
                        llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
                        if llama_usage:
                            prompt_tok, comp_tok = llama_usage
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

                # Cache assembled streaming response — before [DONE] so it always runs.
                # NOTE(review): this stores a chat-shaped body ("message.content") while
                # the non-streaming branch caches the raw text-completion JSON
                # ("choices[].text") under the same key — confirm that
                # openai_nonstream_to_sse and non-streaming readers accept both shapes.
                if _cache is not None and _cache_enabled and text_parts:
                    assembled = orjson.dumps({
                        "model": model,
                        "choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(text_parts)}, "finish_reason": "stop"}],
                        **({"usage": usage_snapshot} if usage_snapshot else {}),
                    }) + b"\n"
                    try:
                        await _cache.set_chat("openai_completions", model, _compl_messages, assembled)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_completions streaming) failed: {_ce}")
                # Final DONE event
                yield b"data: [DONE]\n\n"
            else:
                prompt_tok = 0
                comp_tok = 0
                if async_gen.usage is not None:
                    prompt_tok = async_gen.usage.prompt_tokens or 0
                    comp_tok = async_gen.usage.completion_tokens or 0
                else:
                    llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
                    if llama_usage:
                        prompt_tok, comp_tok = llama_usage
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                json_line = (
                    async_gen.model_dump_json()
                    if hasattr(async_gen, "model_dump_json")
                    else orjson.dumps(async_gen).decode("utf-8")
                )
                cache_bytes = json_line.encode("utf-8") + b"\n"
                yield cache_bytes
                # Cache non-streaming response
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_chat("openai_completions", model, _compl_messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_completions non-streaming) failed: {_ce}")
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 5. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_ocompletions_response(),
        media_type="text/event-stream" if stream else "application/json",
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/v1/models")
async def openai_models_proxy(request: Request):
    """
    Aggregate the model lists of all configured backends into one OpenAI-style list.

    - Ollama endpoints (URL without "/v1"): ``/api/tags``.
    - External OpenAI-compatible endpoints: ``/models``.
    - llama-server endpoints (``config.llama_server_endpoints``): ``/models``.
      Every listed model is returned; no load-status filtering happens here.

    All endpoints are queried concurrently. Each entry gets both an ``id`` and
    a ``name`` (one mirrored from the other), and the combined list is
    de-duplicated on ``name`` with groups in the order Ollama, external, llama.
    """
    config = get_config()

    ollama_eps = [ep for ep in config.endpoints if "/v1" not in ep]
    ext_openai_eps = [ep for ep in config.endpoints if is_ext_openai_endpoint(ep)]
    # dict.fromkeys de-duplicates while keeping configuration order; a set would
    # make the output order vary between processes (str hash randomization).
    llama_eps = list(dict.fromkeys(config.llama_server_endpoints))

    tasks = (
        [fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8) for ep in ollama_eps]
        + [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in ext_openai_eps]
        + [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in llama_eps]
    )
    # One concurrent gather: worst-case latency is a single timeout, not three.
    # Results come back in task order, preserving the group order above.
    results = await asyncio.gather(*tasks) if tasks else []

    data = []
    for modellist in results:
        for entry in modellist or []:
            if "id" not in entry:
                # Relabel Ollama-style entries with an OpenAI Model.id from Model.name
                entry["id"] = entry.get("name", "")
            else:
                entry["name"] = entry["id"]
            data.append(entry)

    # Return a deduplicated list of unique models for inference
    return JSONResponse(
        content={"data": dedupe_on_keys(data, ["name"])},
        status_code=200,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _rerank_upstream_url(endpoint: str, llama_server_endpoints) -> str:
    """Return the upstream rerank URL for *endpoint*.

    llama-server endpoints may be configured with or without a ``/v1`` suffix
    (and possibly a trailing slash). The suffix is detected on the normalized
    URL rather than by substring search, so host names such as
    ``http://v1-box:8080`` or a trailing ``/`` cannot yield a wrong path.
    External OpenAI-compatible endpoints use the ``/v1`` base from ``ep2base``.
    """
    if endpoint in llama_server_endpoints:
        base = endpoint.rstrip("/")
        if base.endswith("/v1"):
            return f"{base}/rerank"
        return f"{base}/v1/rerank"
    return f"{ep2base(endpoint)}/rerank"


@router.post("/v1/rerank")
@router.post("/rerank")
async def rerank_proxy(request: Request):
    """
    Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint.

    Compatible with the Jina/Cohere rerank API convention used by llama-server,
    vLLM, and services such as Cohere and Jina AI.

    Ollama does not natively support reranking; requests routed to a plain Ollama
    endpoint will receive a 501 Not Implemented response.

    Request body:
        model (str, required) – reranker model name
        query (str, required) – search query
        documents (list[str], required) – candidate documents to rank
        top_n (int, optional) – limit returned results (default: all)
        return_documents (bool, optional) – include document text in results
        max_tokens_per_doc (int, optional) – truncation limit per document

    Response (Jina/Cohere-compatible, passed through from the upstream):
        {
            "id": "...",
            "model": "...",
            "usage": {"prompt_tokens": N, "total_tokens": N},
            "results": [{"index": 0, "relevance_score": 0.95}, ...]
        }

    Raises:
        HTTPException(400): body is not valid JSON / not a JSON object, or a
            required field is missing.
        HTTPException(404): no endpoint serves ``model``.
        HTTPException(501): the selected endpoint is a plain Ollama instance.
        HTTPException(502): the upstream could not be reached, timed out, or
            returned a success status with a body that is not valid JSON.
        HTTPException(<upstream status>): the upstream answered with >= 400.
    """
    config = get_config()

    # orjson accepts bytes directly and reports invalid UTF-8 as a
    # JSONDecodeError, so malformed bodies become a 400 rather than a 500.
    try:
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    query = payload.get("query")
    documents = payload.get("documents")

    if not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not query:
        raise HTTPException(status_code=400, detail="Missing required field 'query'")
    if not isinstance(documents, list) or not documents:
        raise HTTPException(status_code=400, detail="Missing or empty required field 'documents' (must be a non-empty list)")

    # Determine which endpoint serves this model
    try:
        endpoint, tracking_model = await choose_endpoint(model)
    except RuntimeError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    # From here on the usage slot taken by choose_endpoint must be released on
    # every exit path (including the 501 below and any error while building
    # the request), hence a single try/finally around all remaining work.
    try:
        # Ollama endpoints have no native rerank support
        if not is_openai_compatible(endpoint):
            raise HTTPException(
                status_code=501,
                detail=(
                    f"Endpoint '{endpoint}' is a plain Ollama instance which does not support "
                    "reranking. Use a llama-server or OpenAI-compatible endpoint with a "
                    "dedicated reranker model."
                ),
            )

        if ":latest" in model:
            model = model.split(":latest")[0]

        # Build upstream rerank request body – forward only recognised fields
        upstream_payload: dict = {"model": model, "query": query, "documents": documents}
        for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
            if optional_key in payload:
                upstream_payload[optional_key] = payload[optional_key]

        rerank_url = _rerank_upstream_url(endpoint, config.llama_server_endpoints)

        api_key = config.api_keys.get(endpoint, "no-key")
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }

        client: aiohttp.ClientSession = get_session(endpoint)
        try:
            async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp:
                status = resp.status
                response_bytes = await resp.read()
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            # Unreachable / timed-out upstream: report a gateway error instead
            # of letting the raw client exception surface as a 500.
            raise HTTPException(
                status_code=502,
                detail=_mask_secrets(f"Upstream rerank request to '{endpoint}' failed: {e!r}"),
            ) from e

        if status >= 400:
            raise HTTPException(
                status_code=status,
                detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")),
            )

        try:
            data = orjson.loads(response_bytes)
        except orjson.JSONDecodeError as e:
            raise HTTPException(
                status_code=502,
                detail=_mask_secrets(f"Upstream rerank endpoint '{endpoint}' returned invalid JSON: {e}"),
            ) from e

        # Record token usage if the upstream returned a usage object.
        usage = data.get("usage") if isinstance(data, dict) else None
        if isinstance(usage, dict):
            # Reranking has no completion tokens. Some providers report only
            # ``total_tokens``; fall back to it so usage is not recorded as 0.
            prompt_tok = usage.get("prompt_tokens") or usage.get("total_tokens") or 0
            if prompt_tok:
                await token_queue.put((endpoint, tracking_model, prompt_tok, 0))

        return JSONResponse(content=data)
    finally:
        await decrement_usage(endpoint, tracking_model)
|