feat: add /v1/rerank endpoint with Cohere, Jina, and llama.cpp compatibility
This commit is contained in:
parent
ad4a1d07b2
commit
cac0580eec
1 changed files with 121 additions and 1 deletions
122
router.py
122
router.py
|
|
@ -3004,7 +3004,127 @@ async def openai_models_proxy(request: Request):
|
||||||
)
|
)
|
||||||
|
|
||||||
# -------------------------------------------------------------
# 25. API route – OpenAI/Jina/Cohere compatible Rerank
# -------------------------------------------------------------
@app.post("/v1/rerank")
|
||||||
|
@app.post("/rerank")
|
||||||
|
async def rerank_proxy(request: Request):
|
||||||
|
"""
|
||||||
|
Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint.
|
||||||
|
|
||||||
|
Compatible with the Jina/Cohere rerank API convention used by llama-server,
|
||||||
|
vLLM, and services such as Cohere and Jina AI.
|
||||||
|
|
||||||
|
Ollama does not natively support reranking; requests routed to a plain Ollama
|
||||||
|
endpoint will receive a 501 Not Implemented response.
|
||||||
|
|
||||||
|
Request body:
|
||||||
|
model (str, required) – reranker model name
|
||||||
|
query (str, required) – search query
|
||||||
|
documents (list[str], required) – candidate documents to rank
|
||||||
|
top_n (int, optional) – limit returned results (default: all)
|
||||||
|
return_documents (bool, optional) – include document text in results
|
||||||
|
max_tokens_per_doc (int, optional) – truncation limit per document
|
||||||
|
|
||||||
|
Response (Jina/Cohere-compatible):
|
||||||
|
{
|
||||||
|
"id": "...",
|
||||||
|
"model": "...",
|
||||||
|
"usage": {"prompt_tokens": N, "total_tokens": N},
|
||||||
|
"results": [{"index": 0, "relevance_score": 0.95}, ...]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
body_bytes = await request.body()
|
||||||
|
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||||
|
|
||||||
|
model = payload.get("model")
|
||||||
|
query = payload.get("query")
|
||||||
|
documents = payload.get("documents")
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
raise HTTPException(status_code=400, detail="Missing required field 'model'")
|
||||||
|
if not query:
|
||||||
|
raise HTTPException(status_code=400, detail="Missing required field 'query'")
|
||||||
|
if not isinstance(documents, list) or not documents:
|
||||||
|
raise HTTPException(status_code=400, detail="Missing or empty required field 'documents' (must be a non-empty list)")
|
||||||
|
except orjson.JSONDecodeError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
|
# Determine which endpoint serves this model
|
||||||
|
try:
|
||||||
|
endpoint = await choose_endpoint(model)
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|
||||||
|
# Ollama endpoints have no native rerank support
|
||||||
|
if not is_openai_compatible(endpoint):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=501,
|
||||||
|
detail=(
|
||||||
|
f"Endpoint '{endpoint}' is a plain Ollama instance which does not support "
|
||||||
|
"reranking. Use a llama-server or OpenAI-compatible endpoint with a "
|
||||||
|
"dedicated reranker model."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Normalize model name for tracking
|
||||||
|
tracking_model = get_tracking_model(endpoint, model)
|
||||||
|
if ":latest" in model:
|
||||||
|
model = model.split(":latest")[0]
|
||||||
|
|
||||||
|
await increment_usage(endpoint, tracking_model)
|
||||||
|
|
||||||
|
# Build upstream rerank request body – forward only recognised fields
|
||||||
|
upstream_payload: dict = {"model": model, "query": query, "documents": documents}
|
||||||
|
for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
|
||||||
|
if optional_key in payload:
|
||||||
|
upstream_payload[optional_key] = payload[optional_key]
|
||||||
|
|
||||||
|
# Determine upstream URL:
|
||||||
|
# llama-server exposes /v1/rerank (base already contains /v1 for llama_server_endpoints)
|
||||||
|
# External OpenAI endpoints expose /rerank under their /v1 base
|
||||||
|
if endpoint in config.llama_server_endpoints:
|
||||||
|
# llama-server: endpoint may or may not already contain /v1
|
||||||
|
if "/v1" in endpoint:
|
||||||
|
rerank_url = f"{endpoint}/rerank"
|
||||||
|
else:
|
||||||
|
rerank_url = f"{endpoint}/v1/rerank"
|
||||||
|
else:
|
||||||
|
# External OpenAI-compatible: ep2base gives us the /v1 base
|
||||||
|
rerank_url = f"{ep2base(endpoint)}/rerank"
|
||||||
|
|
||||||
|
api_key = config.api_keys.get(endpoint, "no-key")
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
}
|
||||||
|
|
||||||
|
client: aiohttp.ClientSession = app_state["session"]
|
||||||
|
try:
|
||||||
|
async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp:
|
||||||
|
response_bytes = await resp.read()
|
||||||
|
if resp.status >= 400:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=resp.status,
|
||||||
|
detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")),
|
||||||
|
)
|
||||||
|
data = orjson.loads(response_bytes)
|
||||||
|
|
||||||
|
# Record token usage if the upstream returned a usage object
|
||||||
|
usage = data.get("usage") or {}
|
||||||
|
prompt_tok = usage.get("prompt_tokens") or 0
|
||||||
|
total_tok = usage.get("total_tokens") or 0
|
||||||
|
# For reranking there are no completion tokens; we record prompt tokens only
|
||||||
|
if prompt_tok or total_tok:
|
||||||
|
await token_queue.put((endpoint, tracking_model, prompt_tok, 0))
|
||||||
|
|
||||||
|
return JSONResponse(content=data)
|
||||||
|
finally:
|
||||||
|
await decrement_usage(endpoint, tracking_model)
|
||||||
|
|
||||||
|
# -------------------------------------------------------------
# 26. Serve the static front‑end
# -------------------------------------------------------------
# Expose everything under ./static at the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue