Merge pull request #30 from nomyo-ai/dev-v0.7.x

- improved performance
- added /v1/rerank endpoint
- refactor of choose_endpoints for atomic upgrade of usage counters
- fixes for security, type- and keyerrors
- improved database handling
This commit is contained in:
Alpha Nerd 2026-03-04 11:01:22 +01:00 committed by GitHub
commit e51969a2bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 268 additions and 138 deletions

View file

@ -14,12 +14,17 @@ Copy/Clone the repository, edit the config.yaml by adding your Ollama backend se
```
# config.yaml
# Ollama or OpenAI API V1 endpoints
endpoints:
- http://ollama0:11434
- http://ollama1:11434
- http://ollama2:11434
- https://api.openai.com/v1
# llama.cpp server endpoints
llama_server_endpoints:
- http://192.168.0.33:8889/v1
# Maximum concurrent connections *per endpoint:model pair*
max_concurrent_connections: 2
@ -34,6 +39,7 @@ api_keys:
"http://192.168.0.51:11434": "ollama"
"http://192.168.0.52:11434": "ollama"
"https://api.openai.com/v1": "${OPENAI_KEY}"
"http://192.168.0.33:8889/v1": "llama"
```
Run the NOMYO Router in a dedicated virtual environment, install the requirements and run with uvicorn:
@ -58,6 +64,12 @@ finally you can
uvicorn router:app --host 127.0.0.1 --port 12434
```
in <u>very</u> high concurrent scenarios (> 500 simultaneous requests) you can also run with uvloop
```
uvicorn router:app --host 127.0.0.1 --port 12434 --loop uvloop
```
## Docker Deployment
Build the container image locally:
@ -98,7 +110,6 @@ This way the Ollama backend servers are utilized more efficiently than by simply u
NOMYO Router also supports OpenAI API compatible v1 backend servers.
## Supplying the router API key
If you set `nomyo-router-api-key` in `config.yaml` (or `NOMYO_ROUTER_API_KEY` env), every request to NOMYO Router must include the key:
@ -107,6 +118,7 @@ If you set `nomyo-router-api-key` in `config.yaml` (or `NOMYO_ROUTER_API_KEY` en
- Query param (fallback): `?api_key=<router_key>`
Examples:
```bash
curl -H "Authorization: Bearer $NOMYO_ROUTER_API_KEY" http://localhost:12434/api/tags
curl "http://localhost:12434/api/tags?api_key=$NOMYO_ROUTER_API_KEY"

41
db.py
View file

@ -63,6 +63,7 @@ class TokenDatabase:
)
''')
await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_timestamp ON token_time_series(timestamp)')
await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_model_ts ON token_time_series(model, timestamp)')
await db.commit()
async def update_token_counts(self, endpoint: str, model: str, input_tokens: int, output_tokens: int):
@ -178,6 +179,46 @@ class TokenDatabase:
'timestamp': row[5]
}
async def get_time_series_for_model(self, model: str, limit: int = 50000):
"""Get time series entries for a specific model, newest first.
Uses the (model, timestamp) composite index so the DB does the filtering
instead of returning all rows and discarding non-matching ones in Python.
"""
db = await self._get_connection()
async with self._operation_lock:
async with db.execute('''
SELECT endpoint, input_tokens, output_tokens, total_tokens, timestamp
FROM token_time_series
WHERE model = ?
ORDER BY timestamp DESC
LIMIT ?
''', (model, limit)) as cursor:
async for row in cursor:
yield {
'endpoint': row[0],
'input_tokens': row[1],
'output_tokens': row[2],
'total_tokens': row[3],
'timestamp': row[4],
}
async def get_endpoint_distribution_for_model(self, model: str) -> dict:
"""Return total tokens per endpoint for a specific model as a plain dict.
Computed entirely in SQL so no Python-side aggregation is needed.
"""
db = await self._get_connection()
async with self._operation_lock:
async with db.execute('''
SELECT endpoint, SUM(total_tokens)
FROM token_time_series
WHERE model = ?
GROUP BY endpoint
''', (model,)) as cursor:
rows = await cursor.fetchall()
return {row[0]: row[1] for row in rows}
async def get_token_counts_for_model(self, model):
"""Get token counts for a specific model, aggregated across all endpoints."""
db = await self._get_connection()

View file

@ -36,5 +36,6 @@ tqdm==4.67.1
typing-inspection==0.4.1
typing_extensions==4.14.1
uvicorn==0.38.0
uvloop
yarl==1.20.1
aiosqlite

350
router.py
View file

@ -2,11 +2,11 @@
title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing
author: alpha-nerd-nomyo
author_url: https://github.com/nomyo-ai
version: 0.6
version: 0.7
license: AGPL
"""
# -------------------------------------------------------------
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math
try:
import truststore; truststore.inject_into_ssl()
except ImportError:
@ -75,7 +75,7 @@ def _mask_secrets(text: str) -> str:
return text
# OpenAI-style keys (sk-...) and generic "api key" mentions
text = re.sub(r"sk-[A-Za-z0-9]{4}[A-Za-z0-9_-]*", "sk-***redacted***", text)
text = re.sub(r"(?i)(api[-_ ]key\\s*[:=]\\s*)([^\\s]+)", r"\\1***redacted***", text)
text = re.sub(r"(?i)(api[-_ ]key\s*[:=]\s*)([^\s]+)", r"\1***redacted***", text)
return text
# ------------------------------------------------------------------
@ -374,8 +374,11 @@ def _extract_llama_quant(name: str) -> str:
def _is_llama_model_loaded(item: dict) -> bool:
"""Return True if a llama-server /v1/models item has status 'loaded'.
Handles both dict format ({"value": "loaded"}) and plain string ("loaded")."""
Handles both dict format ({"value": "loaded"}) and plain string ("loaded").
If no status field is present, the model is always-loaded (not dynamically managed)."""
status = item.get("status")
if status is None:
return True # No status field: model is always loaded (e.g. single-model servers)
if isinstance(status, dict):
return status.get("value") == "loaded"
if isinstance(status, str):
@ -925,11 +928,12 @@ async def decrement_usage(endpoint: str, model: str) -> None:
# usage_counts.pop(endpoint, None)
await publish_snapshot()
async def _make_chat_request(endpoint: str, model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
"""
Helper function to make a chat request to a specific endpoint.
Handles endpoint selection, client creation, usage tracking, and request execution.
"""
endpoint, tracking_model = await choose_endpoint(model) # selects and atomically reserves
use_openai = is_openai_compatible(endpoint)
if use_openai:
if ":latest" in model:
@ -959,10 +963,6 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
else:
client = ollama.AsyncClient(host=endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
try:
if use_openai:
start_ts = time.perf_counter()
@ -1054,18 +1054,11 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
moe_reqs = []
# Generate 3 responses
response1_endpoint = await choose_endpoint(model)
response1_task = asyncio.create_task(_make_chat_request(response1_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
response2_endpoint = await choose_endpoint(model)
response2_task = asyncio.create_task(_make_chat_request(response2_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
response3_endpoint = await choose_endpoint(model)
response3_task = asyncio.create_task(_make_chat_request(response3_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
# Generate 3 responses — choose_endpoint is called inside _make_chat_request and
# atomically reserves a slot, so all 3 tasks see each other's load immediately.
response1_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
response2_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
response3_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
responses = await asyncio.gather(response1_task, response2_task, response3_task)
@ -1074,17 +1067,9 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
moe_reqs.append(moe_req)
# Generate 3 critiques
critique1_endpoint = await choose_endpoint(model)
critique1_task = asyncio.create_task(_make_chat_request(critique1_endpoint, model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
critique2_endpoint = await choose_endpoint(model)
critique2_task = asyncio.create_task(_make_chat_request(critique2_endpoint, model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
critique3_endpoint = await choose_endpoint(model)
critique3_task = asyncio.create_task(_make_chat_request(critique3_endpoint, model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
critique1_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
critique2_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
critique3_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
critiques = await asyncio.gather(critique1_task, critique2_task, critique3_task)
@ -1092,8 +1077,7 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
m = enhance.moe_select_candidate(query, critiques)
# Generate final response
final_endpoint = await choose_endpoint(model)
return await _make_chat_request(final_endpoint, model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
return await _make_chat_request(model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
def iso8601_ns():
ns = time.time_ns()
@ -1462,7 +1446,7 @@ async def get_usage_counts() -> Dict:
# -------------------------------------------------------------
# 5. Endpoint selection logic (respecting the configurable limit)
# -------------------------------------------------------------
async def choose_endpoint(model: str) -> str:
async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]:
"""
Determine which endpoint to use for the given model while respecting
the `max_concurrent_connections` per endpoint:model pair **and**
@ -1523,7 +1507,8 @@ async def choose_endpoint(model: str) -> str:
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
loaded_sets = await asyncio.gather(*load_tasks)
# Protect all reads of usage_counts with the lock
# Protect all reads/writes of usage_counts with the lock so that selection
# and reservation are atomic — concurrent callers see each other's pending load.
async with usage_lock:
# Helper: current usage for (endpoint, model) using the same normalized key
# that increment_usage/decrement_usage store — raw model names differ from
@ -1541,34 +1526,37 @@ async def choose_endpoint(model: str) -> str:
# Sort ascending for load balancing — all endpoints here already have the
# model loaded, so there is no model-switching cost to optimise for.
loaded_and_free.sort(key=tracking_usage)
# When all candidates are equally idle, randomise to avoid always picking
# the first entry in a stable sort.
if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
return random.choice(loaded_and_free)
selected = random.choice(loaded_and_free)
else:
selected = loaded_and_free[0]
else:
# 4⃣ Endpoints among the candidates that simply have a free slot
endpoints_with_free_slot = [
ep for ep in candidate_endpoints
if tracking_usage(ep) < config.max_concurrent_connections
]
return loaded_and_free[0]
if endpoints_with_free_slot:
# Sort by total endpoint load (ascending) to prefer idle endpoints.
endpoints_with_free_slot.sort(
key=lambda ep: sum(usage_counts.get(ep, {}).values())
)
if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
selected = random.choice(endpoints_with_free_slot)
else:
selected = endpoints_with_free_slot[0]
else:
# 5⃣ All candidate endpoints are saturated pick the least-busy one (will queue)
selected = min(candidate_endpoints, key=tracking_usage)
# 4⃣ Endpoints among the candidates that simply have a free slot
endpoints_with_free_slot = [
ep for ep in candidate_endpoints
if tracking_usage(ep) < config.max_concurrent_connections
]
if endpoints_with_free_slot:
# Sort by total endpoint load (ascending) to prefer idle endpoints.
endpoints_with_free_slot.sort(
key=lambda ep: sum(usage_counts.get(ep, {}).values())
)
if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
return random.choice(endpoints_with_free_slot)
return endpoints_with_free_slot[0]
# 5⃣ All candidate endpoints are saturated pick the least-busy one (will queue)
ep = min(candidate_endpoints, key=tracking_usage)
return ep
tracking_model = get_tracking_model(selected, model)
if reserve:
usage_counts[selected][tracking_model] += 1
await publish_snapshot()
return selected, tracking_model
# -------------------------------------------------------------
# 6. API route Generate
@ -1609,10 +1597,8 @@ async def proxy(request: Request):
raise HTTPException(status_code=400, detail=error_msg) from e
endpoint = await choose_endpoint(model)
endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")
@ -1637,7 +1623,6 @@ async def proxy(request: Request):
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, tracking_model)
# 4. Async generator that streams data and decrements the counter
async def stream_generate_response():
@ -1732,10 +1717,8 @@ async def chat_proxy(request: Request):
opt = True
else:
opt = False
endpoint = await choose_endpoint(model)
endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")
@ -1766,7 +1749,6 @@ async def chat_proxy(request: Request):
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, tracking_model)
# 3. Async generator that streams chat data and decrements the counter
async def stream_chat_response():
try:
@ -1858,10 +1840,8 @@ async def embedding_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")
@ -1869,7 +1849,6 @@ async def embedding_proxy(request: Request):
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, tracking_model)
# 3. Async generator that streams embedding data and decrements the counter
async def stream_embedding_response():
try:
@ -1926,10 +1905,8 @@ async def embed_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")
@ -1937,7 +1914,6 @@ async def embed_proxy(request: Request):
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, tracking_model)
# 3. Async generator that streams embed data and decrements the counter
async def stream_embedding_response():
try:
@ -2035,8 +2011,7 @@ async def show_proxy(request: Request, model: Optional[str] = None):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
#await increment_usage(endpoint, model)
endpoint, _ = await choose_endpoint(model, reserve=False)
client = ollama.AsyncClient(host=endpoint)
@ -2111,22 +2086,10 @@ async def stats_proxy(request: Request, model: Optional[str] = None):
status_code=404, detail="No token data found for this model"
)
# Get time series data for the last 30 days (43200 minutes = 30 days)
# Assuming entries are grouped by minute, 30 days = 43200 entries max
time_series = []
endpoint_totals = defaultdict(int) # Track tokens per endpoint
async for entry in db.get_latest_time_series(limit=50000):
if entry['model'] == model:
time_series.append({
'endpoint': entry['endpoint'],
'timestamp': entry['timestamp'],
'input_tokens': entry['input_tokens'],
'output_tokens': entry['output_tokens'],
'total_tokens': entry['total_tokens']
})
# Accumulate total tokens per endpoint
endpoint_totals[entry['endpoint']] += entry['total_tokens']
time_series = [
entry async for entry in db.get_time_series_for_model(model)
]
endpoint_distribution = await db.get_endpoint_distribution_for_model(model)
return {
'model': model,
@ -2134,7 +2097,7 @@ async def stats_proxy(request: Request, model: Optional[str] = None):
'output_tokens': token_data['output_tokens'],
'total_tokens': token_data['total_tokens'],
'time_series': time_series,
'endpoint_distribution': dict(endpoint_totals)
'endpoint_distribution': endpoint_distribution,
}
# -------------------------------------------------------------
@ -2418,8 +2381,10 @@ async def ps_proxy(request: Request):
})
# 3. Return a JSONResponse with deduplicated currently deployed models
# Deduplicate on 'name' rather than 'digest': llama-server models always
# have digest="" so deduping on digest collapses all of them to one entry.
return JSONResponse(
content={"models": dedupe_on_keys(models['models'], ['digest'])},
content={"models": dedupe_on_keys(models['models'], ['name'])},
status_code=200,
)
@ -2565,7 +2530,7 @@ async def config_proxy(request: Request):
client: aiohttp.ClientSession = app_state["session"]
headers = None
if "/v1" in url:
headers = {"Authorization": "Bearer " + config.api_keys[url]}
headers = {"Authorization": "Bearer " + config.api_keys.get(url, "no-key")}
target_url = f"{url}/models"
else:
target_url = f"{url}/api/version"
@ -2625,10 +2590,7 @@ async def openai_embedding_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
endpoint, tracking_model = await choose_endpoint(model)
if is_openai_compatible(endpoint):
api_key = config.api_keys.get(endpoint, "no-key")
else:
@ -2637,13 +2599,16 @@ async def openai_embedding_proxy(request: Request):
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=api_key)
# 3. Async generator that streams embedding data and decrements the counter
async_gen = await oclient.embeddings.create(input=doc, model=model)
await decrement_usage(endpoint, tracking_model)
# 5. Return a StreamingResponse backed by the generator
return async_gen
try:
async_gen = await oclient.embeddings.create(input=doc, model=model)
result = async_gen.model_dump()
for item in result.get("data", []):
emb = item.get("embedding")
if emb:
item["embedding"] = [0.0 if isinstance(v, float) and not math.isfinite(v) else v for v in emb]
return JSONResponse(content=result)
finally:
await decrement_usage(endpoint, tracking_model)
# -------------------------------------------------------------
# 22. API route OpenAI compatible Chat Completions
@ -2676,12 +2641,21 @@ async def openai_chat_completions_proxy(request: Request):
logprobs = payload.get("logprobs")
top_logprobs = payload.get("top_logprobs")
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
if not isinstance(messages, list):
raise HTTPException(
status_code=400, detail="Missing required field 'messages' (must be a list)"
)
if ":latest" in model:
model = model.split(":latest")
model = model[0]
params = {
"messages": messages,
"messages": messages,
"model": model,
}
@ -2703,23 +2677,11 @@ async def openai_chat_completions_proxy(request: Request):
}
params.update({k: v for k, v in optional_params.items() if v is not None})
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
if not isinstance(messages, list):
raise HTTPException(
status_code=400, detail="Missing required field 'messages' (must be a list)"
)
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
endpoint, tracking_model = await choose_endpoint(model)
base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
# 3. Async generator that streams completions data and decrements the counter
@ -2825,12 +2787,21 @@ async def openai_completions_proxy(request: Request):
max_completion_tokens = payload.get("max_completion_tokens")
suffix = payload.get("suffix")
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
if not prompt:
raise HTTPException(
status_code=400, detail="Missing required field 'prompt'"
)
if ":latest" in model:
model = model.split(":latest")
model = model[0]
params = {
"prompt": prompt,
"prompt": prompt,
"model": model,
}
@ -2849,23 +2820,11 @@ async def openai_completions_proxy(request: Request):
}
params.update({k: v for k, v in optional_params.items() if v is not None})
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
if not prompt:
raise HTTPException(
status_code=400, detail="Missing required field 'prompt'"
)
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
endpoint, tracking_model = await choose_endpoint(model)
base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
@ -3001,7 +2960,124 @@ async def openai_models_proxy(request: Request):
)
# -------------------------------------------------------------
# 25. Serve the static frontend
# 25. API route OpenAI/Jina/Cohere compatible Rerank
# -------------------------------------------------------------
@app.post("/v1/rerank")
@app.post("/rerank")
async def rerank_proxy(request: Request):
    """
    Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint.

    Compatible with the Jina/Cohere rerank API convention used by llama-server,
    vLLM, and services such as Cohere and Jina AI.

    Ollama does not natively support reranking; requests routed to a plain Ollama
    endpoint will receive a 501 Not Implemented response.

    Request body:
        model (str, required) - reranker model name
        query (str, required) - search query
        documents (list[str], required) - candidate documents to rank
        top_n (int, optional) - limit returned results (default: all)
        return_documents (bool, optional) - include document text in results
        max_tokens_per_doc (int, optional) - truncation limit per document

    Response (Jina/Cohere-compatible):
        {
            "id": "...",
            "model": "...",
            "usage": {"prompt_tokens": N, "total_tokens": N},
            "results": [{"index": 0, "relevance_score": 0.95}, ...]
        }

    Raises:
        HTTPException 400 - malformed JSON or missing required fields
        HTTPException 404 - no endpoint serves the requested model
        HTTPException 501 - selected endpoint is a plain Ollama instance
        HTTPException 4xx/5xx - upstream error, body passed through (secrets masked)
    """
    # 1. Parse and validate the request body. HTTPException raised here is not
    #    caught by the except clause below (it only targets JSON decode errors).
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
        model = payload.get("model")
        query = payload.get("query")
        documents = payload.get("documents")
        if not model:
            raise HTTPException(status_code=400, detail="Missing required field 'model'")
        if not query:
            raise HTTPException(status_code=400, detail="Missing required field 'query'")
        if not isinstance(documents, list) or not documents:
            raise HTTPException(status_code=400, detail="Missing or empty required field 'documents' (must be a non-empty list)")
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Determine which endpoint serves this model. choose_endpoint also
    #    atomically reserves a usage slot, so every exit path below must
    #    release it via decrement_usage.
    try:
        endpoint, tracking_model = await choose_endpoint(model)
    except RuntimeError as e:
        raise HTTPException(status_code=404, detail=str(e))

    # Ollama endpoints have no native rerank support — release the reserved
    # slot before rejecting the request.
    if not is_openai_compatible(endpoint):
        await decrement_usage(endpoint, tracking_model)
        raise HTTPException(
            status_code=501,
            detail=(
                f"Endpoint '{endpoint}' is a plain Ollama instance which does not support "
                "reranking. Use a llama-server or OpenAI-compatible endpoint with a "
                "dedicated reranker model."
            ),
        )

    # Strip the Ollama-style ":latest" suffix for OpenAI-compatible upstreams.
    if ":latest" in model:
        model = model.split(":latest")[0]

    # 3. Build the upstream rerank request body — forward only recognised fields.
    upstream_payload: dict = {"model": model, "query": query, "documents": documents}
    for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
        if optional_key in payload:
            upstream_payload[optional_key] = payload[optional_key]

    # 4. Determine the upstream URL:
    #    - llama-server exposes /v1/rerank (base may already contain /v1 for
    #      llama_server_endpoints)
    #    - external OpenAI endpoints expose /rerank under their /v1 base
    if endpoint in config.llama_server_endpoints:
        # llama-server: endpoint may or may not already contain /v1
        if "/v1" in endpoint:
            rerank_url = f"{endpoint}/rerank"
        else:
            rerank_url = f"{endpoint}/v1/rerank"
    else:
        # External OpenAI-compatible: ep2base gives us the /v1 base
        rerank_url = f"{ep2base(endpoint)}/rerank"

    api_key = config.api_keys.get(endpoint, "no-key")
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    # 5. Forward the request on the shared aiohttp session and always release
    #    the usage slot, even when the upstream fails.
    client: aiohttp.ClientSession = app_state["session"]
    try:
        async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp:
            response_bytes = await resp.read()
            if resp.status >= 400:
                # Pass the upstream error through, but never leak API keys.
                raise HTTPException(
                    status_code=resp.status,
                    detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")),
                )
            data = orjson.loads(response_bytes)

            # Record token usage if the upstream returned a usage object.
            usage = data.get("usage") or {}
            prompt_tok = usage.get("prompt_tokens") or 0
            total_tok = usage.get("total_tokens") or 0
            # For reranking there are no completion tokens; we record prompt tokens only.
            # NOTE(review): assumes token_queue consumers accept (endpoint, model,
            # input_tokens, output_tokens) tuples — matches the other routes here.
            if prompt_tok or total_tok:
                await token_queue.put((endpoint, tracking_model, prompt_tok, 0))

            return JSONResponse(content=data)
    finally:
        await decrement_usage(endpoint, tracking_model)
# -------------------------------------------------------------
# 26. Serve the static frontend
# -------------------------------------------------------------
app.mount("/static", StaticFiles(directory="static"), name="static")
@ -3095,7 +3171,7 @@ async def usage_stream(request: Request):
# -------------------------------------------------------------
@app.on_event("startup")
async def startup_event() -> None:
global config, db
global config, db, token_worker_task, flush_task
# Load YAML config (or use defaults if not present)
config_path = _config_path_from_env()
config = Config.from_yaml(config_path)