diff --git a/README.md b/README.md
index d7e08c1..1780c40 100644
--- a/README.md
+++ b/README.md
@@ -14,12 +14,17 @@ Copy/Clone the repository, edit the config.yaml by adding your Ollama backend se
```
# config.yaml
+# Ollama or OpenAI API v1 endpoints
endpoints:
- http://ollama0:11434
- http://ollama1:11434
- http://ollama2:11434
- https://api.openai.com/v1
+# llama.cpp server endpoints
+llama_server_endpoints:
+ - http://192.168.0.33:8889/v1
+
# Maximum concurrent connections *per endpoint‑model pair*
max_concurrent_connections: 2
@@ -34,6 +39,7 @@ api_keys:
"http://192.168.0.51:11434": "ollama"
"http://192.168.0.52:11434": "ollama"
"https://api.openai.com/v1": "${OPENAI_KEY}"
+ "http://192.168.0.33:8889/v1": "llama"
```
Run the NOMYO Router in a dedicated virtual environment, install the requirements and run with uvicorn:
@@ -58,6 +64,12 @@ finally you can
uvicorn router:app --host 127.0.0.1 --port 12434
```
+In highly concurrent scenarios (> 500 simultaneous requests) you can also run uvicorn with the uvloop event loop:
+
+```
+uvicorn router:app --host 127.0.0.1 --port 12434 --loop uvloop
+```
+
## Docker Deployment
Build the container image locally:
@@ -98,7 +110,6 @@ This way the Ollama backend servers are utilized more efficient than by simply u
NOMYO Router also supports OpenAI API compatible v1 backend servers.
-
## Supplying the router API key
If you set `nomyo-router-api-key` in `config.yaml` (or `NOMYO_ROUTER_API_KEY` env), every request to NOMYO Router must include the key:
@@ -107,6 +118,7 @@ If you set `nomyo-router-api-key` in `config.yaml` (or `NOMYO_ROUTER_API_KEY` en
- Query param (fallback): `?api_key=`
Examples:
+
```bash
curl -H "Authorization: Bearer $NOMYO_ROUTER_API_KEY" http://localhost:12434/api/tags
curl "http://localhost:12434/api/tags?api_key=$NOMYO_ROUTER_API_KEY"
diff --git a/db.py b/db.py
index 11df49c..af7b252 100644
--- a/db.py
+++ b/db.py
@@ -63,6 +63,7 @@ class TokenDatabase:
)
''')
await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_timestamp ON token_time_series(timestamp)')
+ await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_model_ts ON token_time_series(model, timestamp)')
await db.commit()
async def update_token_counts(self, endpoint: str, model: str, input_tokens: int, output_tokens: int):
@@ -178,6 +179,46 @@ class TokenDatabase:
'timestamp': row[5]
}
+ async def get_time_series_for_model(self, model: str, limit: int = 50000):
+ """Get time series entries for a specific model, newest first.
+
+ Uses the (model, timestamp) composite index so the DB does the filtering
+ instead of returning all rows and discarding non-matching ones in Python.
+ """
+ db = await self._get_connection()
+ async with self._operation_lock:
+ async with db.execute('''
+ SELECT endpoint, input_tokens, output_tokens, total_tokens, timestamp
+ FROM token_time_series
+ WHERE model = ?
+ ORDER BY timestamp DESC
+ LIMIT ?
+ ''', (model, limit)) as cursor:
+ async for row in cursor:
+ yield {
+ 'endpoint': row[0],
+ 'input_tokens': row[1],
+ 'output_tokens': row[2],
+ 'total_tokens': row[3],
+ 'timestamp': row[4],
+ }
+
+ async def get_endpoint_distribution_for_model(self, model: str) -> dict:
+ """Return total tokens per endpoint for a specific model as a plain dict.
+
+ Computed entirely in SQL so no Python-side aggregation is needed.
+ """
+ db = await self._get_connection()
+ async with self._operation_lock:
+ async with db.execute('''
+ SELECT endpoint, SUM(total_tokens)
+ FROM token_time_series
+ WHERE model = ?
+ GROUP BY endpoint
+ ''', (model,)) as cursor:
+ rows = await cursor.fetchall()
+ return {row[0]: row[1] for row in rows}
+
async def get_token_counts_for_model(self, model):
"""Get token counts for a specific model, aggregated across all endpoints."""
db = await self._get_connection()
diff --git a/requirements.txt b/requirements.txt
index aa51a0f..f1db419 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -36,5 +36,6 @@ tqdm==4.67.1
typing-inspection==0.4.1
typing_extensions==4.14.1
uvicorn==0.38.0
+uvloop
yarl==1.20.1
aiosqlite
diff --git a/router.py b/router.py
index 1ef096a..691d11d 100644
--- a/router.py
+++ b/router.py
@@ -2,11 +2,11 @@
title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing
author: alpha-nerd-nomyo
author_url: https://github.com/nomyo-ai
-version: 0.6
+version: 0.7
license: AGPL
"""
# -------------------------------------------------------------
-import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets
+import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math
try:
import truststore; truststore.inject_into_ssl()
except ImportError:
@@ -75,7 +75,7 @@ def _mask_secrets(text: str) -> str:
return text
# OpenAI-style keys (sk-...) and generic "api key" mentions
text = re.sub(r"sk-[A-Za-z0-9]{4}[A-Za-z0-9_-]*", "sk-***redacted***", text)
- text = re.sub(r"(?i)(api[-_ ]key\\s*[:=]\\s*)([^\\s]+)", r"\\1***redacted***", text)
+ text = re.sub(r"(?i)(api[-_ ]key\s*[:=]\s*)([^\s]+)", r"\1***redacted***", text)
return text
# ------------------------------------------------------------------
@@ -374,8 +374,11 @@ def _extract_llama_quant(name: str) -> str:
def _is_llama_model_loaded(item: dict) -> bool:
"""Return True if a llama-server /v1/models item has status 'loaded'.
- Handles both dict format ({"value": "loaded"}) and plain string ("loaded")."""
+ Handles both dict format ({"value": "loaded"}) and plain string ("loaded").
+ If no status field is present, the model is treated as always loaded (not dynamically managed).
status = item.get("status")
+ if status is None:
+ return True # No status field: model is always loaded (e.g. single-model servers)
if isinstance(status, dict):
return status.get("value") == "loaded"
if isinstance(status, str):
@@ -925,11 +928,12 @@ async def decrement_usage(endpoint: str, model: str) -> None:
# usage_counts.pop(endpoint, None)
await publish_snapshot()
-async def _make_chat_request(endpoint: str, model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
+async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
"""
Helper function to make a chat request to a specific endpoint.
Handles endpoint selection, client creation, usage tracking, and request execution.
"""
+ endpoint, tracking_model = await choose_endpoint(model) # selects and atomically reserves
use_openai = is_openai_compatible(endpoint)
if use_openai:
if ":latest" in model:
@@ -959,10 +963,6 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
else:
client = ollama.AsyncClient(host=endpoint)
- # Normalize model name for tracking so it matches the PS table key
- tracking_model = get_tracking_model(endpoint, model)
- await increment_usage(endpoint, tracking_model)
-
try:
if use_openai:
start_ts = time.perf_counter()
@@ -1054,18 +1054,11 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
moe_reqs = []
- # Generate 3 responses
- response1_endpoint = await choose_endpoint(model)
- response1_task = asyncio.create_task(_make_chat_request(response1_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
- await asyncio.sleep(0.01) # Small delay to allow usage count to update
-
- response2_endpoint = await choose_endpoint(model)
- response2_task = asyncio.create_task(_make_chat_request(response2_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
- await asyncio.sleep(0.01) # Small delay to allow usage count to update
-
- response3_endpoint = await choose_endpoint(model)
- response3_task = asyncio.create_task(_make_chat_request(response3_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
- await asyncio.sleep(0.01) # Small delay to allow usage count to update
+ # Generate 3 responses — choose_endpoint is called inside _make_chat_request and
+ # atomically reserves a slot, so all 3 tasks see each other's load immediately.
+ response1_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
+ response2_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
+ response3_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
responses = await asyncio.gather(response1_task, response2_task, response3_task)
@@ -1074,17 +1067,9 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
moe_reqs.append(moe_req)
# Generate 3 critiques
- critique1_endpoint = await choose_endpoint(model)
- critique1_task = asyncio.create_task(_make_chat_request(critique1_endpoint, model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
- await asyncio.sleep(0.01) # Small delay to allow usage count to update
-
- critique2_endpoint = await choose_endpoint(model)
- critique2_task = asyncio.create_task(_make_chat_request(critique2_endpoint, model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
- await asyncio.sleep(0.01) # Small delay to allow usage count to update
-
- critique3_endpoint = await choose_endpoint(model)
- critique3_task = asyncio.create_task(_make_chat_request(critique3_endpoint, model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
- await asyncio.sleep(0.01) # Small delay to allow usage count to update
+ critique1_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
+ critique2_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
+ critique3_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
critiques = await asyncio.gather(critique1_task, critique2_task, critique3_task)
@@ -1092,8 +1077,7 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
m = enhance.moe_select_candidate(query, critiques)
# Generate final response
- final_endpoint = await choose_endpoint(model)
- return await _make_chat_request(final_endpoint, model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
+ return await _make_chat_request(model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
def iso8601_ns():
ns = time.time_ns()
@@ -1462,7 +1446,7 @@ async def get_usage_counts() -> Dict:
# -------------------------------------------------------------
# 5. Endpoint selection logic (respecting the configurable limit)
# -------------------------------------------------------------
-async def choose_endpoint(model: str) -> str:
+async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]:
"""
Determine which endpoint to use for the given model while respecting
the `max_concurrent_connections` per endpoint‑model pair **and**
@@ -1523,7 +1507,8 @@ async def choose_endpoint(model: str) -> str:
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
loaded_sets = await asyncio.gather(*load_tasks)
- # Protect all reads of usage_counts with the lock
+ # Protect all reads/writes of usage_counts with the lock so that selection
+ # and reservation are atomic — concurrent callers see each other's pending load.
async with usage_lock:
# Helper: current usage for (endpoint, model) using the same normalized key
# that increment_usage/decrement_usage store — raw model names differ from
@@ -1541,34 +1526,37 @@ async def choose_endpoint(model: str) -> str:
# Sort ascending for load balancing — all endpoints here already have the
# model loaded, so there is no model-switching cost to optimise for.
loaded_and_free.sort(key=tracking_usage)
-
# When all candidates are equally idle, randomise to avoid always picking
# the first entry in a stable sort.
if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
- return random.choice(loaded_and_free)
+ selected = random.choice(loaded_and_free)
+ else:
+ selected = loaded_and_free[0]
+ else:
+ # 4️⃣ Endpoints among the candidates that simply have a free slot
+ endpoints_with_free_slot = [
+ ep for ep in candidate_endpoints
+ if tracking_usage(ep) < config.max_concurrent_connections
+ ]
- return loaded_and_free[0]
+ if endpoints_with_free_slot:
+ # Sort by total endpoint load (ascending) to prefer idle endpoints.
+ endpoints_with_free_slot.sort(
+ key=lambda ep: sum(usage_counts.get(ep, {}).values())
+ )
+ if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
+ selected = random.choice(endpoints_with_free_slot)
+ else:
+ selected = endpoints_with_free_slot[0]
+ else:
+ # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue)
+ selected = min(candidate_endpoints, key=tracking_usage)
- # 4️⃣ Endpoints among the candidates that simply have a free slot
- endpoints_with_free_slot = [
- ep for ep in candidate_endpoints
- if tracking_usage(ep) < config.max_concurrent_connections
- ]
-
- if endpoints_with_free_slot:
- # Sort by total endpoint load (ascending) to prefer idle endpoints.
- endpoints_with_free_slot.sort(
- key=lambda ep: sum(usage_counts.get(ep, {}).values())
- )
-
- if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
- return random.choice(endpoints_with_free_slot)
-
- return endpoints_with_free_slot[0]
-
- # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue)
- ep = min(candidate_endpoints, key=tracking_usage)
- return ep
+ tracking_model = get_tracking_model(selected, model)
+ if reserve:
+ usage_counts[selected][tracking_model] += 1
+ await publish_snapshot()
+ return selected, tracking_model
# -------------------------------------------------------------
# 6. API route – Generate
@@ -1609,10 +1597,8 @@ async def proxy(request: Request):
raise HTTPException(status_code=400, detail=error_msg) from e
- endpoint = await choose_endpoint(model)
+ endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
- # Normalize model name for tracking so it matches the PS table key
- tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")
@@ -1637,7 +1623,6 @@ async def proxy(request: Request):
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
- await increment_usage(endpoint, tracking_model)
# 4. Async generator that streams data and decrements the counter
async def stream_generate_response():
@@ -1732,10 +1717,8 @@ async def chat_proxy(request: Request):
opt = True
else:
opt = False
- endpoint = await choose_endpoint(model)
+ endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
- # Normalize model name for tracking so it matches the PS table key
- tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")
@@ -1766,7 +1749,6 @@ async def chat_proxy(request: Request):
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
- await increment_usage(endpoint, tracking_model)
# 3. Async generator that streams chat data and decrements the counter
async def stream_chat_response():
try:
@@ -1858,10 +1840,8 @@ async def embedding_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
- endpoint = await choose_endpoint(model)
+ endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
- # Normalize model name for tracking so it matches the PS table key
- tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")
@@ -1869,7 +1849,6 @@ async def embedding_proxy(request: Request):
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
- await increment_usage(endpoint, tracking_model)
# 3. Async generator that streams embedding data and decrements the counter
async def stream_embedding_response():
try:
@@ -1926,10 +1905,8 @@ async def embed_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
- endpoint = await choose_endpoint(model)
+ endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
- # Normalize model name for tracking so it matches the PS table key
- tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")
@@ -1937,7 +1914,6 @@ async def embed_proxy(request: Request):
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
- await increment_usage(endpoint, tracking_model)
# 3. Async generator that streams embed data and decrements the counter
async def stream_embedding_response():
try:
@@ -2035,8 +2011,7 @@ async def show_proxy(request: Request, model: Optional[str] = None):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
- endpoint = await choose_endpoint(model)
- #await increment_usage(endpoint, model)
+ endpoint, _ = await choose_endpoint(model, reserve=False)
client = ollama.AsyncClient(host=endpoint)
@@ -2111,22 +2086,10 @@ async def stats_proxy(request: Request, model: Optional[str] = None):
status_code=404, detail="No token data found for this model"
)
- # Get time series data for the last 30 days (43200 minutes = 30 days)
- # Assuming entries are grouped by minute, 30 days = 43200 entries max
- time_series = []
- endpoint_totals = defaultdict(int) # Track tokens per endpoint
-
- async for entry in db.get_latest_time_series(limit=50000):
- if entry['model'] == model:
- time_series.append({
- 'endpoint': entry['endpoint'],
- 'timestamp': entry['timestamp'],
- 'input_tokens': entry['input_tokens'],
- 'output_tokens': entry['output_tokens'],
- 'total_tokens': entry['total_tokens']
- })
- # Accumulate total tokens per endpoint
- endpoint_totals[entry['endpoint']] += entry['total_tokens']
+ time_series = [
+ entry async for entry in db.get_time_series_for_model(model)
+ ]
+ endpoint_distribution = await db.get_endpoint_distribution_for_model(model)
return {
'model': model,
@@ -2134,7 +2097,7 @@ async def stats_proxy(request: Request, model: Optional[str] = None):
'output_tokens': token_data['output_tokens'],
'total_tokens': token_data['total_tokens'],
'time_series': time_series,
- 'endpoint_distribution': dict(endpoint_totals)
+ 'endpoint_distribution': endpoint_distribution,
}
# -------------------------------------------------------------
@@ -2418,8 +2381,10 @@ async def ps_proxy(request: Request):
})
# 3. Return a JSONResponse with deduplicated currently deployed models
+ # Deduplicate on 'name' rather than 'digest': llama-server models always
+ # have digest="" so deduping on digest collapses all of them to one entry.
return JSONResponse(
- content={"models": dedupe_on_keys(models['models'], ['digest'])},
+ content={"models": dedupe_on_keys(models['models'], ['name'])},
status_code=200,
)
@@ -2565,7 +2530,7 @@ async def config_proxy(request: Request):
client: aiohttp.ClientSession = app_state["session"]
headers = None
if "/v1" in url:
- headers = {"Authorization": "Bearer " + config.api_keys[url]}
+ headers = {"Authorization": "Bearer " + config.api_keys.get(url, "no-key")}
target_url = f"{url}/models"
else:
target_url = f"{url}/api/version"
@@ -2625,10 +2590,7 @@ async def openai_embedding_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
- endpoint = await choose_endpoint(model)
- # Normalize model name for tracking so it matches the PS table key
- tracking_model = get_tracking_model(endpoint, model)
- await increment_usage(endpoint, tracking_model)
+ endpoint, tracking_model = await choose_endpoint(model)
if is_openai_compatible(endpoint):
api_key = config.api_keys.get(endpoint, "no-key")
else:
@@ -2637,13 +2599,16 @@ async def openai_embedding_proxy(request: Request):
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=api_key)
- # 3. Async generator that streams embedding data and decrements the counter
- async_gen = await oclient.embeddings.create(input=doc, model=model)
-
- await decrement_usage(endpoint, tracking_model)
-
- # 5. Return a StreamingResponse backed by the generator
- return async_gen
+ try:
+ async_gen = await oclient.embeddings.create(input=doc, model=model)
+ result = async_gen.model_dump()
+ for item in result.get("data", []):
+ emb = item.get("embedding")
+ if emb:
+ item["embedding"] = [0.0 if isinstance(v, float) and not math.isfinite(v) else v for v in emb]
+ return JSONResponse(content=result)
+ finally:
+ await decrement_usage(endpoint, tracking_model)
# -------------------------------------------------------------
# 22. API route – OpenAI compatible Chat Completions
@@ -2676,12 +2641,21 @@ async def openai_chat_completions_proxy(request: Request):
logprobs = payload.get("logprobs")
top_logprobs = payload.get("top_logprobs")
+ if not model:
+ raise HTTPException(
+ status_code=400, detail="Missing required field 'model'"
+ )
+ if not isinstance(messages, list):
+ raise HTTPException(
+ status_code=400, detail="Missing required field 'messages' (must be a list)"
+ )
+
if ":latest" in model:
model = model.split(":latest")
model = model[0]
params = {
- "messages": messages,
+ "messages": messages,
"model": model,
}
@@ -2703,23 +2677,11 @@ async def openai_chat_completions_proxy(request: Request):
}
params.update({k: v for k, v in optional_params.items() if v is not None})
-
- if not model:
- raise HTTPException(
- status_code=400, detail="Missing required field 'model'"
- )
- if not isinstance(messages, list):
- raise HTTPException(
- status_code=400, detail="Missing required field 'messages' (must be a list)"
- )
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
- endpoint = await choose_endpoint(model)
- # Normalize model name for tracking so it matches the PS table key
- tracking_model = get_tracking_model(endpoint, model)
- await increment_usage(endpoint, tracking_model)
+ endpoint, tracking_model = await choose_endpoint(model)
base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
# 3. Async generator that streams completions data and decrements the counter
@@ -2825,12 +2787,21 @@ async def openai_completions_proxy(request: Request):
max_completion_tokens = payload.get("max_completion_tokens")
suffix = payload.get("suffix")
+ if not model:
+ raise HTTPException(
+ status_code=400, detail="Missing required field 'model'"
+ )
+ if not prompt:
+ raise HTTPException(
+ status_code=400, detail="Missing required field 'prompt'"
+ )
+
if ":latest" in model:
model = model.split(":latest")
model = model[0]
params = {
- "prompt": prompt,
+ "prompt": prompt,
"model": model,
}
@@ -2849,23 +2820,11 @@ async def openai_completions_proxy(request: Request):
}
params.update({k: v for k, v in optional_params.items() if v is not None})
-
- if not model:
- raise HTTPException(
- status_code=400, detail="Missing required field 'model'"
- )
- if not prompt:
- raise HTTPException(
- status_code=400, detail="Missing required field 'prompt'"
- )
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
- endpoint = await choose_endpoint(model)
- # Normalize model name for tracking so it matches the PS table key
- tracking_model = get_tracking_model(endpoint, model)
- await increment_usage(endpoint, tracking_model)
+ endpoint, tracking_model = await choose_endpoint(model)
base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
@@ -3001,7 +2960,124 @@ async def openai_models_proxy(request: Request):
)
# -------------------------------------------------------------
-# 25. Serve the static front‑end
+# 25. API route – OpenAI/Jina/Cohere compatible Rerank
+# -------------------------------------------------------------
+@app.post("/v1/rerank")
+@app.post("/rerank")
+async def rerank_proxy(request: Request):
+ """
+ Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint.
+
+ Compatible with the Jina/Cohere rerank API convention used by llama-server,
+ vLLM, and services such as Cohere and Jina AI.
+
+ Ollama does not natively support reranking; requests routed to a plain Ollama
+ endpoint will receive a 501 Not Implemented response.
+
+ Request body:
+ model (str, required) – reranker model name
+ query (str, required) – search query
+ documents (list[str], required) – candidate documents to rank
+ top_n (int, optional) – limit returned results (default: all)
+ return_documents (bool, optional) – include document text in results
+ max_tokens_per_doc (int, optional) – truncation limit per document
+
+ Response (Jina/Cohere-compatible):
+ {
+ "id": "...",
+ "model": "...",
+ "usage": {"prompt_tokens": N, "total_tokens": N},
+ "results": [{"index": 0, "relevance_score": 0.95}, ...]
+ }
+ """
+ try:
+ body_bytes = await request.body()
+ payload = orjson.loads(body_bytes.decode("utf-8"))
+
+ model = payload.get("model")
+ query = payload.get("query")
+ documents = payload.get("documents")
+
+ if not model:
+ raise HTTPException(status_code=400, detail="Missing required field 'model'")
+ if not query:
+ raise HTTPException(status_code=400, detail="Missing required field 'query'")
+ if not isinstance(documents, list) or not documents:
+ raise HTTPException(status_code=400, detail="Missing or empty required field 'documents' (must be a non-empty list)")
+ except orjson.JSONDecodeError as e:
+ raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
+
+ # Determine which endpoint serves this model
+ try:
+ endpoint, tracking_model = await choose_endpoint(model)
+ except RuntimeError as e:
+ raise HTTPException(status_code=404, detail=str(e))
+
+ # Ollama endpoints have no native rerank support
+ if not is_openai_compatible(endpoint):
+ await decrement_usage(endpoint, tracking_model)
+ raise HTTPException(
+ status_code=501,
+ detail=(
+ f"Endpoint '{endpoint}' is a plain Ollama instance which does not support "
+ "reranking. Use a llama-server or OpenAI-compatible endpoint with a "
+ "dedicated reranker model."
+ ),
+ )
+
+ if ":latest" in model:
+ model = model.split(":latest")[0]
+
+ # Build upstream rerank request body – forward only recognised fields
+ upstream_payload: dict = {"model": model, "query": query, "documents": documents}
+ for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
+ if optional_key in payload:
+ upstream_payload[optional_key] = payload[optional_key]
+
+ # Determine upstream URL:
+ # llama-server exposes /v1/rerank (base already contains /v1 for llama_server_endpoints)
+ # External OpenAI endpoints expose /rerank under their /v1 base
+ if endpoint in config.llama_server_endpoints:
+ # llama-server: endpoint may or may not already contain /v1
+ if "/v1" in endpoint:
+ rerank_url = f"{endpoint}/rerank"
+ else:
+ rerank_url = f"{endpoint}/v1/rerank"
+ else:
+ # External OpenAI-compatible: ep2base gives us the /v1 base
+ rerank_url = f"{ep2base(endpoint)}/rerank"
+
+ api_key = config.api_keys.get(endpoint, "no-key")
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {api_key}",
+ }
+
+ client: aiohttp.ClientSession = app_state["session"]
+ try:
+ async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp:
+ response_bytes = await resp.read()
+ if resp.status >= 400:
+ raise HTTPException(
+ status_code=resp.status,
+ detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")),
+ )
+ data = orjson.loads(response_bytes)
+
+ # Record token usage if the upstream returned a usage object
+ usage = data.get("usage") or {}
+ prompt_tok = usage.get("prompt_tokens") or 0
+ total_tok = usage.get("total_tokens") or 0
+ # For reranking there are no completion tokens; we record prompt tokens only
+ if prompt_tok or total_tok:
+ await token_queue.put((endpoint, tracking_model, prompt_tok, 0))
+
+ return JSONResponse(content=data)
+ finally:
+ await decrement_usage(endpoint, tracking_model)
+
+# -------------------------------------------------------------
+# 26. Serve the static front‑end
# -------------------------------------------------------------
app.mount("/static", StaticFiles(directory="static"), name="static")
@@ -3095,7 +3171,7 @@ async def usage_stream(request: Request):
# -------------------------------------------------------------
@app.on_event("startup")
async def startup_event() -> None:
- global config, db
+ global config, db, token_worker_task, flush_task
# Load YAML config (or use defaults if not present)
config_path = _config_path_from_env()
config = Config.from_yaml(config_path)