"""OpenAI-compatible routes (``/v1/embeddings``, ``/v1/chat/completions``,
``/v1/completions``, ``/v1/models``, ``/v1/rerank`` and ``/rerank``).

The chat-completions handler carries the context-window logic: a proactive
trim for small-context models whose ``n_ctx`` has been learned, a reactive
sliding-window trim on ``exceed_context_size_error`` and a tool-stripping last
resort. Both the chat-completions and completions handlers mark
``(endpoint, model)`` unhealthy on upstream connection failures
(``_mark_backend_unhealthy``) so the next request is rerouted. The streaming
branches assemble cached responses on the fly so caching works for both
streaming and non-streaming clients.
"""

import asyncio
import base64
import math
import re

import aiohttp
import orjson
from fastapi import APIRouter, HTTPException, Request
from starlette.responses import JSONResponse, StreamingResponse

from cache import get_llm_cache, openai_nonstream_to_sse
from config import get_config
from context_window import (
    _count_message_tokens,
    _trim_messages_for_context,
    _calibrated_trim_target,
    _endpoint_nctx,
    _CTX_TRIM_SMALL_LIMIT,
)
from fingerprint import _conversation_fingerprint
from security import _mask_secrets
from state import token_queue, app_state, default_headers
from backends.health import _is_backend_connection_error, _mark_backend_unhealthy
from backends.normalize import (
    dedupe_on_keys,
    ep2base,
    is_ext_openai_endpoint,
    is_openai_compatible,
    _normalize_llama_model_name,
)
from backends.probe import fetch
from backends.sessions import _make_openai_client, get_session
from requests.messages import _strip_assistant_prefill, _strip_images_from_messages
from requests.rechunk import rechunk
from routing import choose_endpoint, decrement_usage

router = APIRouter()

# Substrings in an upstream error message that signal the prompt did not fit
# the backend's context window.
_CTX_ERROR_MARKERS = ("exceed_context_size_error", "exceeds the available context size")

# Fallback parsers used when the SDK did not expose a structured error body.
# Accept both Python-repr ('key') and JSON ("key") quoting of the field names.
_N_CTX_RE = re.compile(r"""['"]n_ctx['"]\s*:\s*(\d+)""")
_N_PROMPT_TOKENS_RE = re.compile(r"""['"]n_prompt_tokens['"]\s*:\s*(\d+)""")

# Optional OpenAI parameters forwarded verbatim (when not None), in this order.
_CHAT_OPTIONAL_KEYS = (
    "tools",
    "response_format",
    "stream_options",
    "max_completion_tokens",
    "max_tokens",
    "temperature",
    "top_p",
    "seed",
    "presence_penalty",
    "frequency_penalty",
    "stop",
    "stream",
    "logprobs",
    "top_logprobs",
)
_COMPLETION_OPTIONAL_KEYS = (
    "frequency_penalty",
    "presence_penalty",
    "seed",
    "stop",
    "stream",
    "stream_options",
    "temperature",
    "top_p",
    "max_tokens",
    "max_completion_tokens",
    "suffix",
)


# ---------------------------------------------------------------------------
# Shared request/response helpers
# ---------------------------------------------------------------------------

async def _read_json_object(request: Request) -> dict:
    """Read the request body and parse it as a JSON object.

    ``orjson.loads`` is given the raw bytes so invalid UTF-8 is reported as
    invalid JSON instead of escaping as an unhandled ``UnicodeDecodeError``.

    Raises:
        HTTPException(400): the body is not valid JSON, or is not an object.
    """
    body_bytes = await request.body()
    try:
        payload = orjson.loads(body_bytes)
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Invalid JSON: request body must be an object")
    return payload


def _cache_requested(payload: dict) -> bool:
    """Return True when the client opted into the LLM cache via ``{"nomyo": {"cache": true}}``."""
    nomyo = payload.get("nomyo")
    return isinstance(nomyo, dict) and bool(nomyo.get("cache", False))


def _strip_latest_tag(model: str) -> str:
    """Drop a ``:latest`` tag (and anything after it) from a model name."""
    return model.split(":latest")[0]


def _optional_params(payload: dict, keys: tuple, stream) -> dict:
    """Collect the optional OpenAI parameters in ``keys`` that are not None.

    ``stream_options`` is only forwarded for streaming requests (the OpenAI API
    rejects it otherwise) and defaults to ``{"include_usage": True}`` so the
    final usage chunk is available for token accounting.
    """
    values = {key: payload.get(key) for key in keys}
    if "stream_options" in values:
        values["stream_options"] = (values["stream_options"] or {"include_usage": True}) if stream else None
    return {key: value for key, value in values.items() if value is not None}


def _json_text(obj) -> str:
    """Serialise an SDK response object (pydantic) or a plain value to a JSON string."""
    if hasattr(obj, "model_dump_json"):
        return obj.model_dump_json()
    # orjson.dumps returns bytes; decode so callers can embed/encode uniformly.
    return orjson.dumps(obj).decode("utf-8")


def _sse_event(data: str) -> bytes:
    """Frame one JSON payload as a server-sent event."""
    return f"data: {data}\n\n".encode("utf-8")


def _usage_tokens(obj) -> tuple[int, int]:
    """Return ``(prompt_tokens, completion_tokens)`` for a response or chunk.

    Prefers the OpenAI ``usage`` object; otherwise falls back to
    ``rechunk.extract_usage_from_llama_timings``. Returns ``(0, 0)`` when
    neither source yields numbers.
    """
    usage = getattr(obj, "usage", None)
    if usage is not None:
        return usage.prompt_tokens or 0, usage.completion_tokens or 0
    llama_usage = rechunk.extract_usage_from_llama_timings(obj)
    if llama_usage:
        prompt_tok, comp_tok = llama_usage
        return prompt_tok, comp_tok
    return 0, 0


def _cached_response(cached, model: str, stream) -> StreamingResponse:
    """Serve a cached non-streaming body, converted to SSE when the client streams."""
    if stream:
        sse = openai_nonstream_to_sse(cached, model)

        async def _serve_cached_stream():
            yield sse

        return StreamingResponse(_serve_cached_stream(), media_type="text/event-stream")

    async def _serve_cached_json():
        yield cached

    return StreamingResponse(_serve_cached_json(), media_type="application/json")


# ---------------------------------------------------------------------------
# /v1/embeddings
# ---------------------------------------------------------------------------

def _normalize_embedding_input(doc):
    """Reduce multimodal embedding input to text.

    Dict parts keep only ``type == "text"`` entries (their ``text``); other
    list items pass through unchanged. A single remaining item is unwrapped.
    A lone ``{"type": "text"}`` dict becomes its text.
    """
    if isinstance(doc, list):
        normalized = [
            item.get("text", "") if isinstance(item, dict) else item
            for item in doc
            if not isinstance(item, dict) or item.get("type") == "text"
        ]
        return normalized if len(normalized) != 1 else normalized[0]
    if isinstance(doc, dict) and doc.get("type") == "text":
        return doc.get("text", "")
    return doc


def _finite_or_zero(value):
    """Replace NaN/inf floats with 0.0 (they cannot be serialised to JSON)."""
    return 0.0 if isinstance(value, float) and not math.isfinite(value) else value


@router.post("/v1/embeddings")
async def openai_embedding_proxy(request: Request):
    """
    Proxy an OpenAI API compatible embedding request to the selected backend
    and reply with the embeddings.

    Raises:
        HTTPException(400): invalid JSON or missing ``model`` / ``input``.
    """
    config = get_config()

    # 1. Parse and validate request
    payload = await _read_json_object(request)
    model = payload.get("model")
    doc = _normalize_embedding_input(payload.get("input"))
    if not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not doc:
        raise HTTPException(status_code=400, detail="Missing required field 'input'")

    # 2. Endpoint logic — everything after selection sits inside try/finally so
    # the usage counter is always released.
    endpoint, tracking_model = await choose_endpoint(model)
    try:
        api_key = config.api_keys.get(endpoint, "no-key") if is_openai_compatible(endpoint) else "ollama"
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=api_key)
        response = await oclient.embeddings.create(input=doc, model=model)
        result = response.model_dump()
        for item in result.get("data", []):
            emb = item.get("embedding")
            if emb and isinstance(emb, list):
                item["embedding"] = [_finite_or_zero(v) for v in emb]
        return JSONResponse(content=result)
    finally:
        await decrement_usage(endpoint, tracking_model)


# ---------------------------------------------------------------------------
# /v1/chat/completions
# ---------------------------------------------------------------------------

def _reject_svg_images(messages: list) -> None:
    """Raise HTTP 400 if any message carries an SVG ``image_url`` part."""
    for msg in messages:
        content = msg.get("content") if isinstance(msg, dict) else None
        if not isinstance(content, list):
            continue
        for item in content:
            if not isinstance(item, dict) or item.get("type") != "image_url":
                continue
            image_url = item.get("image_url")
            url = image_url.get("url") if isinstance(image_url, dict) else None
            if not isinstance(url, str):
                continue
            lowered = url.lower()
            if lowered.startswith("data:image/svg") or lowered.endswith(".svg"):
                raise HTTPException(
                    status_code=400,
                    detail="SVG images are not supported. Please convert the image to PNG or JPEG before sending.",
                )


async def _inline_image_part(item):
    """Return ``item`` with a remote image URL replaced by a base64 data URL.

    Non-image parts, data URLs and failed fetches (including HTTP error
    statuses) are returned unchanged, so this is best-effort.
    """
    if not isinstance(item, dict) or item.get("type") != "image_url":
        return item
    image_url = item.get("image_url")
    url = image_url.get("url", "") if isinstance(image_url, dict) else ""
    if not url or not isinstance(url, str) or url.startswith("data:"):
        return item
    # NOTE(security): this fetches an arbitrary client-supplied URL from the
    # proxy host; consider an allow-list if untrusted clients can reach it.
    try:
        http: aiohttp.ClientSession = app_state["session"]
        async with http.get(url) as resp:
            resp.raise_for_status()
            ctype = resp.headers.get("Content-Type", "image/jpeg").split(";")[0].strip() or "image/jpeg"
            img_bytes = await resp.read()
    except Exception as _ie:
        print(f"[image] Failed to fetch image URL: {_ie}")
        return item
    b64 = base64.b64encode(img_bytes).decode("utf-8")
    return {"type": "image_url", "image_url": {"url": f"data:{ctype};base64,{b64}"}}


async def _inline_remote_images(msgs: list) -> list:
    """Fetch remote image URLs and convert them to base64 data URLs so
    Ollama/llama-server can handle them without making outbound HTTP requests."""
    resolved = []
    for msg in msgs:
        content = msg.get("content") if isinstance(msg, dict) else None
        if not isinstance(content, list):
            resolved.append(msg)
            continue
        new_content = [await _inline_image_part(item) for item in content]
        resolved.append({**msg, "content": new_content})
    return resolved


def _apply_proactive_trim(send_params: dict, known_nctx) -> dict:
    """Trim messages up front for small-ctx models we've already seen run out of space.

    Targets 3/4 of ``known_nctx`` divided by 1.2. Returns ``send_params``
    unchanged when no small window is known or the estimate already fits.
    """
    if not known_nctx or known_nctx > _CTX_TRIM_SMALL_LIMIT:
        return send_params
    target = int((known_nctx - known_nctx // 4) / 1.2)
    msgs = send_params.get("messages", [])
    estimate = _count_message_tokens(msgs)
    if estimate <= target:
        return send_params
    trimmed = _trim_messages_for_context(msgs, known_nctx, target_tokens=target)
    print(
        f"[ctx-pre] n_ctx={known_nctx} est={estimate} target={target} dropped={len(msgs) - len(trimmed)}",
        flush=True,
    )
    return {**send_params, "messages": trimmed}


def _is_ctx_error(text: str) -> bool:
    """True when an upstream error message reports a context-window overflow."""
    return any(marker in text for marker in _CTX_ERROR_MARKERS)


def _int_or_zero(value) -> int:
    """Coerce ``value`` to int, returning 0 for None or unparsable values."""
    try:
        return int(value or 0)
    except (TypeError, ValueError):
        return 0


def _parse_ctx_error(exc: Exception, text: str) -> tuple[int, int]:
    """Extract ``(n_ctx, n_prompt_tokens)`` from a context-overflow error (0 = unknown).

    openai-python exposes the unwrapped ``error`` object as ``exc.body``; the
    full ``{"error": {...}}`` envelope is accepted too. Missing fields fall
    back to a regex over the exception text.
    """
    err_body = getattr(exc, "body", None)
    err_detail: dict = {}
    if isinstance(err_body, dict):
        inner = err_body.get("error")
        err_detail = inner if isinstance(inner, dict) else err_body
    n_ctx = _int_or_zero(err_detail.get("n_ctx"))
    n_prompt = _int_or_zero(err_detail.get("n_prompt_tokens"))
    if not n_ctx:
        match = _N_CTX_RE.search(text)
        if match:
            n_ctx = int(match.group(1))
    if not n_prompt:
        match = _N_PROMPT_TOKENS_RE.search(text)
        if match:
            n_prompt = int(match.group(1))
    return n_ctx, n_prompt


async def _retry_after_ctx_trim(oclient, send_params: dict, exc: Exception, err_text: str,
                                endpoint: str, nctx_model: str):
    """Sliding-window trim (context-shift at message level) and retry.

    Retry 1 resends the trimmed messages. If that still overflows, retry 2 also
    drops ``tools``/``tool_choice`` (tool definitions can consume the budget).
    Re-raises ``exc`` when ``n_ctx`` cannot be determined.
    """
    n_ctx_limit, actual_tokens = _parse_ctx_error(exc, err_text)
    print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True)
    if not n_ctx_limit:
        raise exc
    if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
        # Remember the small window under the same key the proactive trim reads.
        _endpoint_nctx[(endpoint, nctx_model)] = n_ctx_limit

    msgs_to_trim = send_params.get("messages", [])
    try:
        cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
        trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
    except Exception as _helper_exc:
        print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True)
        raise
    dropped = len(msgs_to_trim) - len(trimmed_messages)
    print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True)

    try:
        result = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages})
    except Exception as e2:
        if not _is_ctx_error(str(e2)):
            raise
        print("[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True)
        params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")}
        result = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages})
        print("[ctx-trim] retry-2 ok", flush=True)
        return result
    print("[ctx-trim] retry-1 ok", flush=True)
    return result


async def _create_chat_completion(oclient, send_params: dict, endpoint: str, model: str, nctx_model: str):
    """Issue the upstream chat-completions call, applying the retry ladder.

    Handles, in order: models without tool support (retry without ``tools``),
    context overflow (trim and retry), connection failures (mark the backend
    unhealthy and re-raise), and models without image support (retry
    text-only). Never touches the usage counter — the caller releases it on
    failure.
    """
    try:
        return await oclient.chat.completions.create(**send_params)
    except Exception as e:
        err_text = str(e)
        is_ctx_err = _is_ctx_error(err_text)
        print(f"[ochat] caught={type(e).__name__} ctx={is_ctx_err} msg={err_text[:120]}", flush=True)

        if "does not support tools" in err_text:
            print("[ochat] retry: no tools", flush=True)
            params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
            return await oclient.chat.completions.create(**params_without_tools)
        if is_ctx_err:
            return await _retry_after_ctx_trim(oclient, send_params, e, err_text, endpoint, nctx_model)
        if _is_backend_connection_error(e):
            # Upstream connection failed (e.g. llama-server in router mode
            # whose delegated worker died). Mark (endpoint, model) so the
            # next request reroutes; the client will retry this one.
            print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
            await _mark_backend_unhealthy(endpoint, model, err_text)
            raise
        if "image input is not supported" in err_text:
            print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages")
            text_only = _strip_images_from_messages(send_params.get("messages", []))
            return await oclient.chat.completions.create(**{**send_params, "messages": text_only})
        raise


@router.post("/v1/chat/completions")
async def openai_chat_completions_proxy(request: Request):
    """
    Proxy an OpenAI API compatible chat completions request to the selected
    backend and reply with an SSE stream (``stream`` truthy) or a JSON body.

    Raises:
        HTTPException(400): invalid JSON, missing ``model``/``messages``, or
            SVG image input.
    """
    config = get_config()

    # 1. Parse and validate request
    payload = await _read_json_object(request)
    model = payload.get("model")
    messages = payload.get("messages")
    stream = payload.get("stream")
    _cache_enabled = _cache_requested(payload)
    if not isinstance(model, str) or not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="Missing required field 'messages' (must be a list)")
    model = _strip_latest_tag(model)
    messages = _strip_assistant_prefill(messages)

    params = {"messages": messages, "model": model}
    params.update(_optional_params(payload, _CHAT_OPTIONAL_KEYS, stream))

    # Reject unsupported image formats (SVG) before doing any work
    _reject_svg_images(messages)

    # Cache lookup — before endpoint selection
    _cache = get_llm_cache()
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_chat("openai_chat", model, messages)
        if _cached is not None:
            return _cached_response(_cached, model, stream)

    # 2. Endpoint logic
    _affinity_key = _conversation_fingerprint(model, messages, None)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)

    # 3. Make the API call in handler scope — try/except inside async generators
    # is unreliable with Starlette's streaming machinery, so errors are resolved
    # here before the generator starts. Any failure (or cancellation) before the
    # generator takes ownership releases the usage counter.
    try:
        # Key for the learned-n_ctx cache; must match on read and write.
        nctx_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model
        oclient = _make_openai_client(
            endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")
        )
        send_params = params
        if not is_ext_openai_endpoint(endpoint):
            send_params = {**params, "messages": await _inline_remote_images(params.get("messages", []))}
        send_params = _apply_proactive_trim(send_params, _endpoint_nctx.get((endpoint, nctx_model)))
        async_gen = await _create_chat_completion(oclient, send_params, endpoint, model, nctx_model)
    except BaseException:
        await decrement_usage(endpoint, tracking_model)
        raise

    # finish_reason=length may just mean the caller's token budget ran out.
    req_max_tok = params.get("max_tokens") or params.get("max_completion_tokens")

    # 4. Async generator — only streams the already-established async_gen and
    # owns the usage counter from here on.
    async def stream_ochat_response():
        try:
            if not stream:
                prompt_tok, comp_tok = _usage_tokens(async_gen)
                if prompt_tok or comp_tok:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                cache_bytes = _json_text(async_gen).encode("utf-8") + b"\n"
                yield cache_bytes
                # Cache non-streaming response
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_chat("openai_chat", model, messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_chat non-streaming) failed: {_ce}")
                return

            content_parts: list[str] = []
            usage_snapshot: dict = {}
            hit_length = False
            last_total_tokens = 0
            async for chunk in async_gen:
                data = _json_text(chunk)
                if chunk.choices:
                    choice = chunk.choices[0]
                    delta = choice.delta
                    has_content = delta.content is not None
                    has_reasoning = (
                        getattr(delta, "reasoning_content", None) is not None
                        or getattr(delta, "reasoning", None) is not None
                    )
                    has_tool_calls = getattr(delta, "tool_calls", None) is not None
                    # Forward the finish chunk too, so clients see finish_reason.
                    if has_content or has_reasoning or has_tool_calls or choice.finish_reason is not None:
                        yield _sse_event(data)
                        if delta.content:
                            content_parts.append(delta.content)
                    if choice.finish_reason == "length":
                        hit_length = True
                elif chunk.usage is not None:
                    # Forward the usage-only final chunk (e.g. from llama-server)
                    yield _sse_event(data)

                prompt_tok, comp_tok = _usage_tokens(chunk)
                if chunk.usage is not None:
                    usage_snapshot = {
                        "prompt_tokens": prompt_tok,
                        "completion_tokens": comp_tok,
                        "total_tokens": prompt_tok + comp_tok,
                    }
                if prompt_tok or comp_tok:
                    last_total_tokens = prompt_tok + comp_tok
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

            # Detect context exhaustion mid-generation for small-ctx models.
            # Evaluated after the loop because the usage chunk usually arrives
            # separately from (after) the finish_reason chunk.
            if hit_length and not req_max_tok and 0 < last_total_tokens <= _CTX_TRIM_SMALL_LIMIT:
                _endpoint_nctx[(endpoint, nctx_model)] = last_total_tokens
                print(
                    f"[ctx-cache] finish_reason=length → cached n_ctx={last_total_tokens} for ({endpoint},{model})",
                    flush=True,
                )

            # Cache assembled streaming response — before [DONE] so it always runs
            if _cache is not None and _cache_enabled and content_parts:
                assembled = orjson.dumps({
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "message": {"role": "assistant", "content": "".join(content_parts)},
                        "finish_reason": "stop",
                    }],
                    **({"usage": usage_snapshot} if usage_snapshot else {}),
                }) + b"\n"
                try:
                    await _cache.set_chat("openai_chat", model, messages, assembled)
                except Exception as _ce:
                    print(f"[cache] set_chat (openai_chat streaming) failed: {_ce}")
            yield b"data: [DONE]\n\n"
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 5. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_ochat_response(),
        media_type="text/event-stream" if stream else "application/json",
    )


# ---------------------------------------------------------------------------
# /v1/completions
# ---------------------------------------------------------------------------

async def _create_completion(oclient, params: dict, endpoint: str, model: str):
    """Issue the upstream completions call; on connection failure mark
    ``(endpoint, model)`` unhealthy before re-raising. Never touches the usage
    counter — the caller releases it on failure."""
    try:
        return await oclient.completions.create(**params)
    except Exception as e:
        if _is_backend_connection_error(e):
            print(f"[ocompl] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
            await _mark_backend_unhealthy(endpoint, model, str(e))
        raise


@router.post("/v1/completions")
async def openai_completions_proxy(request: Request):
    """
    Proxy an OpenAI API compatible (legacy text) completions request to the
    selected backend and reply with an SSE stream (``stream`` truthy) or a
    JSON body.

    Raises:
        HTTPException(400): invalid JSON or missing ``model`` / ``prompt``.
    """
    config = get_config()

    # 1. Parse and validate request
    payload = await _read_json_object(request)
    model = payload.get("model")
    prompt = payload.get("prompt")
    stream = payload.get("stream")
    _cache_enabled = _cache_requested(payload)
    if not isinstance(model, str) or not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not prompt:
        raise HTTPException(status_code=400, detail="Missing required field 'prompt'")
    model = _strip_latest_tag(model)

    params = {"prompt": prompt, "model": model}
    params.update(_optional_params(payload, _COMPLETION_OPTIONAL_KEYS, stream))

    # Cache lookup — completions prompt mapped to a single-turn messages list
    _cache = get_llm_cache()
    _compl_messages = [{"role": "user", "content": prompt}]
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_chat("openai_completions", model, _compl_messages)
        if _cached is not None:
            return _cached_response(_cached, model, stream)

    # 2. Endpoint logic
    _affinity_key = _conversation_fingerprint(model, None, prompt)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)

    # 3. Make the API call in handler scope (try/except inside async generators
    # is unreliable); release the usage counter on any failure before streaming.
    try:
        oclient = _make_openai_client(
            endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")
        )
        async_gen = await _create_completion(oclient, params, endpoint, model)
    except BaseException:
        await decrement_usage(endpoint, tracking_model)
        raise

    # 4. Async generator that streams completions data and decrements the counter
    async def stream_ocompletions_response(model=model):
        try:
            if not stream:
                prompt_tok, comp_tok = _usage_tokens(async_gen)
                if prompt_tok or comp_tok:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                cache_bytes = _json_text(async_gen).encode("utf-8") + b"\n"
                yield cache_bytes
                # Cache non-streaming response
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_chat("openai_completions", model, _compl_messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_completions non-streaming) failed: {_ce}")
                return

            text_parts: list[str] = []
            usage_snapshot: dict = {}
            async for chunk in async_gen:
                data = _json_text(chunk)
                if chunk.choices:
                    choice = chunk.choices[0]
                    has_text = getattr(choice, "text", None) is not None
                    has_reasoning = (
                        getattr(choice, "reasoning_content", None) is not None
                        or getattr(choice, "reasoning", None) is not None
                    )
                    if has_text or has_reasoning or choice.finish_reason is not None:
                        yield _sse_event(data)
                        if has_text and choice.text:
                            text_parts.append(choice.text)
                elif chunk.usage is not None:
                    # Forward the usage-only final chunk (e.g. from llama-server)
                    yield _sse_event(data)

                prompt_tok, comp_tok = _usage_tokens(chunk)
                if chunk.usage is not None:
                    usage_snapshot = {
                        "prompt_tokens": prompt_tok,
                        "completion_tokens": comp_tok,
                        "total_tokens": prompt_tok + comp_tok,
                    }
                if prompt_tok or comp_tok:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

            # Cache assembled streaming response — before [DONE] so it always runs.
            # The choice carries both "text" (what /v1/completions clients read)
            # and "message" (the chat shape the cache's SSE converter was fed before).
            if _cache is not None and _cache_enabled and text_parts:
                full_text = "".join(text_parts)
                assembled = orjson.dumps({
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "text": full_text,
                        "message": {"role": "assistant", "content": full_text},
                        "finish_reason": "stop",
                    }],
                    **({"usage": usage_snapshot} if usage_snapshot else {}),
                }) + b"\n"
                try:
                    await _cache.set_chat("openai_completions", model, _compl_messages, assembled)
                except Exception as _ce:
                    print(f"[cache] set_chat (openai_completions streaming) failed: {_ce}")
            # Final DONE event
            yield b"data: [DONE]\n\n"
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 5. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_ocompletions_response(),
        media_type="text/event-stream" if stream else "application/json",
    )


# ---------------------------------------------------------------------------
# /v1/models
# ---------------------------------------------------------------------------

def _with_name_and_id(model: dict) -> dict:
    """Make ``id`` and ``name`` agree: Ollama entries get ``id`` from ``name``,
    OpenAI-style entries get ``name`` from ``id``. Mutates and returns ``model``."""
    if "id" not in model:
        model["id"] = model.get("name", "")
    else:
        model["name"] = model["id"]
    return model


@router.get("/v1/models")
async def openai_models_proxy(request: Request):
    """
    Proxy an OpenAI API models request to all configured backends and reply
    with a list of models deduplicated on ``name``.

    Ollama endpoints (no ``/v1`` in the URL) are queried via ``/api/tags``;
    external OpenAI endpoints and llama-server endpoints via ``/models``.
    All llama-server models returned are listed, not just loaded ones.
    """
    config = get_config()

    ollama_eps = [ep for ep in config.endpoints if "/v1" not in ep]
    ext_openai_eps = [ep for ep in config.endpoints if is_ext_openai_endpoint(ep)]
    # Covers llama_server_endpoints that are not also listed in config.endpoints.
    llama_eps = set(config.llama_server_endpoints)

    # All three groups are queried concurrently (each call has its own timeout).
    ollama_models, ext_openai_models, llama_models = await asyncio.gather(
        asyncio.gather(*(
            fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8)
            for ep in ollama_eps
        )),
        asyncio.gather(*(
            fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
            for ep in ext_openai_eps
        )),
        asyncio.gather(*(
            fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
            for ep in llama_eps
        )),
    )

    models_data = [
        _with_name_and_id(model)
        for group in (ollama_models, ext_openai_models, llama_models)
        for modellist in group
        for model in (modellist or ())
    ]

    # Return a deduplicated list of unique models for inference
    return JSONResponse(
        content={"data": dedupe_on_keys(models_data, ["name"])},
        status_code=200,
    )


# ---------------------------------------------------------------------------
# /v1/rerank, /rerank
# ---------------------------------------------------------------------------

def _rerank_url(endpoint: str, config) -> str:
    """Build the upstream rerank URL for ``endpoint``.

    llama-server endpoints may or may not already include ``/v1``; external
    OpenAI-compatible endpoints expose ``/rerank`` under their ``ep2base`` base.
    """
    if endpoint in config.llama_server_endpoints:
        return f"{endpoint}/rerank" if "/v1" in endpoint else f"{endpoint}/v1/rerank"
    return f"{ep2base(endpoint)}/rerank"


@router.post("/v1/rerank")
@router.post("/rerank")
async def rerank_proxy(request: Request):
    """
    Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint.

    Compatible with the Jina/Cohere rerank API convention used by llama-server,
    vLLM, and services such as Cohere and Jina AI.

    Ollama does not natively support reranking; requests routed to a plain
    Ollama endpoint will receive a 501 Not Implemented response.

    Request body:
        model (str, required) – reranker model name
        query (str, required) – search query
        documents (list[str], required) – candidate documents to rank
        top_n (int, optional) – limit returned results (default: all)
        return_documents (bool, optional) – include document text in results
        max_tokens_per_doc (int, optional) – truncation limit per document

    Response (Jina/Cohere-compatible):
        {
          "id": "...",
          "model": "...",
          "usage": {"prompt_tokens": N, "total_tokens": N},
          "results": [{"index": 0, "relevance_score": 0.95}, ...]
        }

    Raises:
        HTTPException: 400 on invalid input, 404 when no endpoint serves the
            model, 501 for plain Ollama endpoints, the upstream status for
            upstream errors, 502 when the upstream reply is not valid JSON.
    """
    config = get_config()

    payload = await _read_json_object(request)
    model = payload.get("model")
    query = payload.get("query")
    documents = payload.get("documents")
    if not isinstance(model, str) or not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not query:
        raise HTTPException(status_code=400, detail="Missing required field 'query'")
    if not isinstance(documents, list) or not documents:
        raise HTTPException(
            status_code=400,
            detail="Missing or empty required field 'documents' (must be a non-empty list)",
        )

    # Determine which endpoint serves this model
    try:
        endpoint, tracking_model = await choose_endpoint(model)
    except RuntimeError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    # Everything after selection is inside try/finally so the counter is released.
    try:
        # Ollama endpoints have no native rerank support
        if not is_openai_compatible(endpoint):
            raise HTTPException(
                status_code=501,
                detail=(
                    f"Endpoint '{endpoint}' is a plain Ollama instance which does not support "
                    "reranking. Use a llama-server or OpenAI-compatible endpoint with a "
                    "dedicated reranker model."
                ),
            )

        model = _strip_latest_tag(model)

        # Build upstream rerank request body – forward only recognised fields
        upstream_payload: dict = {"model": model, "query": query, "documents": documents}
        for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
            if optional_key in payload:
                upstream_payload[optional_key] = payload[optional_key]

        api_key = config.api_keys.get(endpoint, "no-key")
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }
        client: aiohttp.ClientSession = get_session(endpoint)
        async with client.post(_rerank_url(endpoint, config), json=upstream_payload, headers=headers) as resp:
            response_bytes = await resp.read()
            if resp.status >= 400:
                raise HTTPException(
                    status_code=resp.status,
                    detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")),
                )

        try:
            data = orjson.loads(response_bytes)
        except orjson.JSONDecodeError as e:
            raise HTTPException(status_code=502, detail="Upstream rerank response was not valid JSON") from e

        # Record token usage if the upstream returned a usage object
        usage = data.get("usage") if isinstance(data, dict) else None
        if not isinstance(usage, dict):
            usage = {}
        prompt_tok = usage.get("prompt_tokens") or 0
        total_tok = usage.get("total_tokens") or 0
        # For reranking there are no completion tokens; we record prompt tokens only
        if prompt_tok or total_tok:
            await token_queue.put((endpoint, tracking_model, prompt_tok, 0))

        return JSONResponse(content=data)
    finally:
        await decrement_usage(endpoint, tracking_model)