# -------------------------------------------------------------
# 25. API route – OpenAI/Jina/Cohere compatible Rerank
# -------------------------------------------------------------
@app.post("/v1/rerank")
@app.post("/rerank")
async def rerank_proxy(request: Request):
    """
    Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint.

    Compatible with the Jina/Cohere rerank API convention used by llama-server,
    vLLM, and services such as Cohere and Jina AI.

    Ollama does not natively support reranking; requests routed to a plain Ollama
    endpoint will receive a 501 Not Implemented response.

    Request body:
        model (str, required)              – reranker model name
        query (str, required)              – search query
        documents (list[str], required)    – candidate documents to rank
        top_n (int, optional)              – limit returned results (default: all)
        return_documents (bool, optional)  – include document text in results
        max_tokens_per_doc (int, optional) – truncation limit per document

    Response (Jina/Cohere-compatible):
        {
          "id": "...",
          "model": "...",
          "usage": {"prompt_tokens": N, "total_tokens": N},
          "results": [{"index": 0, "relevance_score": 0.95}, ...]
        }

    Raises:
        HTTPException 400 – body is not valid JSON / not a JSON object, or a
                            required field is missing or malformed
        HTTPException 404 – no endpoint could be chosen for the model
        HTTPException 501 – the chosen endpoint is a plain Ollama instance
        HTTPException 502 – the upstream could not be reached or returned a
                            body that is not a JSON object
        HTTPException <upstream status> – the upstream answered with >= 400
    """
    # orjson accepts raw bytes and reports invalid UTF-8 as JSONDecodeError,
    # so a malformed body becomes a 400 instead of an unhandled decode error.
    try:
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # Valid JSON that is not an object (e.g. a list) has no fields to read.
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    query = payload.get("query")
    documents = payload.get("documents")

    if not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not isinstance(model, str):
        # The model name is used with string operations further down.
        raise HTTPException(status_code=400, detail="Field 'model' must be a string")
    if not query:
        raise HTTPException(status_code=400, detail="Missing required field 'query'")
    if not isinstance(documents, list) or not documents:
        raise HTTPException(status_code=400, detail="Missing or empty required field 'documents' (must be a non-empty list)")

    # Determine which endpoint serves this model
    try:
        endpoint = await choose_endpoint(model)
    except RuntimeError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    # Ollama endpoints have no native rerank support
    if not is_openai_compatible(endpoint):
        raise HTTPException(
            status_code=501,
            detail=(
                f"Endpoint '{endpoint}' is a plain Ollama instance which does not support "
                "reranking. Use a llama-server or OpenAI-compatible endpoint with a "
                "dedicated reranker model."
            ),
        )

    # Normalize model name for tracking
    tracking_model = get_tracking_model(endpoint, model)
    if ":latest" in model:
        model = model.split(":latest")[0]

    # Build upstream rerank request body – forward only recognised fields
    upstream_payload: dict = {"model": model, "query": query, "documents": documents}
    for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
        if optional_key in payload:
            upstream_payload[optional_key] = payload[optional_key]

    # Determine upstream URL:
    #   llama-server: the configured endpoint may or may not already contain /v1
    #   External OpenAI-compatible: ep2base gives us the /v1 base
    if endpoint in config.llama_server_endpoints:
        if "/v1" in endpoint:
            rerank_url = f"{endpoint}/rerank"
        else:
            rerank_url = f"{endpoint}/v1/rerank"
    else:
        rerank_url = f"{ep2base(endpoint)}/rerank"

    api_key = config.api_keys.get(endpoint, "no-key")
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    client: aiohttp.ClientSession = app_state["session"]

    # Increment only once all request preparation has succeeded, directly before
    # the try/finally, so every increment is paired with exactly one decrement.
    await increment_usage(endpoint, tracking_model)
    try:
        async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp:
            response_bytes = await resp.read()
            if resp.status >= 400:
                raise HTTPException(
                    status_code=resp.status,
                    detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")),
                )

        try:
            data = orjson.loads(response_bytes)
        except orjson.JSONDecodeError as e:
            raise HTTPException(
                status_code=502,
                detail="Upstream rerank endpoint returned a non-JSON response",
            ) from e
        if not isinstance(data, dict):
            raise HTTPException(
                status_code=502,
                detail="Upstream rerank endpoint returned an unexpected response (expected a JSON object)",
            )

        # Record token usage if the upstream returned a usage object
        usage = data.get("usage")
        if not isinstance(usage, dict):
            usage = {}
        prompt_tok = usage.get("prompt_tokens") or 0
        total_tok = usage.get("total_tokens") or 0
        # For reranking there are no completion tokens, so when the upstream
        # reports only total_tokens that total is the prompt-side count.
        input_tok = prompt_tok or total_tok
        if input_tok:
            await token_queue.put((endpoint, tracking_model, input_tok, 0))

        return JSONResponse(content=data)
    except aiohttp.ClientError as e:
        # Connection/transport failure talking to the upstream.
        raise HTTPException(
            status_code=502,
            detail=_mask_secrets(f"Upstream rerank request failed: {e}"),
        ) from e
    finally:
        await decrement_usage(endpoint, tracking_model)

# -------------------------------------------------------------
# 26. Serve the static front‑end
# -------------------------------------------------------------
app.mount("/static", StaticFiles(directory="static"), name="static")