feat: add /v1/rerank endpoint with Cohere, Jina, and llama.cpp compatibility
This commit is contained in:
parent
ad4a1d07b2
commit
cac0580eec
1 changed files with 121 additions and 1 deletions
122
router.py
122
router.py
|
|
@ -3004,7 +3004,127 @@ async def openai_models_proxy(request: Request):
|
|||
)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 25. Serve the static front‑end
|
||||
# 25. API route – OpenAI/Jina/Cohere compatible Rerank
|
||||
# -------------------------------------------------------------
|
||||
@app.post("/v1/rerank")
@app.post("/rerank")
async def rerank_proxy(request: Request):
    """
    Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint.

    Compatible with the Jina/Cohere rerank API convention used by llama-server,
    vLLM, and services such as Cohere and Jina AI.

    Ollama does not natively support reranking; requests routed to a plain Ollama
    endpoint will receive a 501 Not Implemented response.

    Request body:
        model (str, required) – reranker model name
        query (str, required) – search query
        documents (list[str], required) – candidate documents to rank
        top_n (int, optional) – limit returned results (default: all)
        return_documents (bool, optional) – include document text in results
        max_tokens_per_doc (int, optional) – truncation limit per document

    Response (Jina/Cohere-compatible):
        {
          "id": "...",
          "model": "...",
          "usage": {"prompt_tokens": N, "total_tokens": N},
          "results": [{"index": 0, "relevance_score": 0.95}, ...]
        }

    Raises:
        HTTPException 400 – body is not valid JSON, is not a JSON object, or a
            required field is missing/invalid.
        HTTPException 404 – ``choose_endpoint`` raised ``RuntimeError`` for the model.
        HTTPException 501 – the chosen endpoint is not OpenAI-compatible.
        HTTPException 502 – the upstream answered with a body that is not a JSON object.
        HTTPException <upstream status> – the upstream answered with status >= 400.
    """
    # orjson accepts raw bytes and reports invalid UTF-8 as JSONDecodeError, so
    # no manual decode is needed (a manual decode could raise an unhandled
    # UnicodeDecodeError and surface as a 500).
    try:
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # A valid JSON document may still be a list/number/string; only objects have fields.
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    query = payload.get("query")
    documents = payload.get("documents")

    if not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not isinstance(model, str):
        # The model name is used for string operations and endpoint lookup below.
        raise HTTPException(status_code=400, detail="Field 'model' must be a string")
    if not query:
        raise HTTPException(status_code=400, detail="Missing required field 'query'")
    if not isinstance(documents, list) or not documents:
        raise HTTPException(status_code=400, detail="Missing or empty required field 'documents' (must be a non-empty list)")

    # Determine which endpoint serves this model
    try:
        endpoint = await choose_endpoint(model)
    except RuntimeError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    # Ollama endpoints have no native rerank support
    if not is_openai_compatible(endpoint):
        raise HTTPException(
            status_code=501,
            detail=(
                f"Endpoint '{endpoint}' is a plain Ollama instance which does not support "
                "reranking. Use a llama-server or OpenAI-compatible endpoint with a "
                "dedicated reranker model."
            ),
        )

    # Normalize model name for tracking (computed from the name as requested,
    # before the ":latest" suffix is stripped for the upstream call)
    tracking_model = get_tracking_model(endpoint, model)
    if ":latest" in model:
        model = model.split(":latest")[0]

    # Build upstream rerank request body – forward only recognised fields
    upstream_payload: dict = {"model": model, "query": query, "documents": documents}
    for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
        if optional_key in payload:
            upstream_payload[optional_key] = payload[optional_key]

    # Determine upstream URL:
    #   llama-server: the configured endpoint may or may not already contain /v1
    #   External OpenAI-compatible: ep2base gives us the /v1 base
    if endpoint in config.llama_server_endpoints:
        if "/v1" in endpoint:
            rerank_url = f"{endpoint}/rerank"
        else:
            rerank_url = f"{endpoint}/v1/rerank"
    else:
        rerank_url = f"{ep2base(endpoint)}/rerank"

    api_key = config.api_keys.get(endpoint, "no-key")
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    client: aiohttp.ClientSession = app_state["session"]

    # Everything that can fail while preparing the request happens above, so the
    # usage counter is only incremented immediately before the try/finally that
    # guarantees the matching decrement.
    await increment_usage(endpoint, tracking_model)
    try:
        async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp:
            response_bytes = await resp.read()
            if resp.status >= 400:
                raise HTTPException(
                    status_code=resp.status,
                    detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")),
                )

        try:
            data = orjson.loads(response_bytes)
        except orjson.JSONDecodeError as e:
            raise HTTPException(
                status_code=502,
                detail=f"Upstream rerank endpoint returned invalid JSON: {e}",
            ) from e
        if not isinstance(data, dict):
            raise HTTPException(
                status_code=502,
                detail="Upstream rerank endpoint returned a non-object JSON response",
            )

        # Record token usage if the upstream returned a usage object
        usage = data.get("usage")
        if not isinstance(usage, dict):
            usage = {}
        # For reranking there are no completion tokens. Some upstreams report only
        # total_tokens, so fall back to it rather than recording zero tokens.
        prompt_tok = usage.get("prompt_tokens") or usage.get("total_tokens") or 0
        if prompt_tok:
            await token_queue.put((endpoint, tracking_model, prompt_tok, 0))

        return JSONResponse(content=data)
    finally:
        await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 26. Serve the static front‑end
|
||||
# -------------------------------------------------------------
|
||||
# Mount the local "static" directory under the /static URL prefix.
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue