Merge pull request #30 from nomyo-ai/dev-v0.7.x
- improved performance - added /v1/rerank endpoint - refactored choose_endpoint for atomic updates of usage counters - fixes for security issues, TypeErrors and KeyErrors - improved database handling
This commit is contained in:
commit
e51969a2bb
4 changed files with 268 additions and 138 deletions
14
README.md
14
README.md
|
|
@ -14,12 +14,17 @@ Copy/Clone the repository, edit the config.yaml by adding your Ollama backend se
|
|||
|
||||
```
|
||||
# config.yaml
|
||||
# Ollama or OpenAI API V1 endpoints
|
||||
endpoints:
|
||||
- http://ollama0:11434
|
||||
- http://ollama1:11434
|
||||
- http://ollama2:11434
|
||||
- https://api.openai.com/v1
|
||||
|
||||
# llama.cpp server endpoints
|
||||
llama_server_endpoints:
|
||||
- http://192.168.0.33:8889/v1
|
||||
|
||||
# Maximum concurrent connections *per endpoint‑model pair*
|
||||
max_concurrent_connections: 2
|
||||
|
||||
|
|
@ -34,6 +39,7 @@ api_keys:
|
|||
"http://192.168.0.51:11434": "ollama"
|
||||
"http://192.168.0.52:11434": "ollama"
|
||||
"https://api.openai.com/v1": "${OPENAI_KEY}"
|
||||
"http://192.168.0.33:8889/v1": "llama"
|
||||
```
|
||||
|
||||
Run the NOMYO Router in a dedicated virtual environment, install the requirements and run with uvicorn:
|
||||
|
|
@ -58,6 +64,12 @@ finally you can
|
|||
uvicorn router:app --host 127.0.0.1 --port 12434
|
||||
```
|
||||
|
||||
In <u>very</u> high-concurrency scenarios (> 500 simultaneous requests) you can also run with uvloop:
|
||||
|
||||
```
|
||||
uvicorn router:app --host 127.0.0.1 --port 12434 --loop uvloop
|
||||
```
|
||||
|
||||
## Docker Deployment
|
||||
|
||||
Build the container image locally:
|
||||
|
|
@ -98,7 +110,6 @@ This way the Ollama backend servers are utilized more efficient than by simply u
|
|||
|
||||
NOMYO Router also supports OpenAI API compatible v1 backend servers.
|
||||
|
||||
|
||||
## Supplying the router API key
|
||||
|
||||
If you set `nomyo-router-api-key` in `config.yaml` (or `NOMYO_ROUTER_API_KEY` env), every request to NOMYO Router must include the key:
|
||||
|
|
@ -107,6 +118,7 @@ If you set `nomyo-router-api-key` in `config.yaml` (or `NOMYO_ROUTER_API_KEY` en
|
|||
- Query param (fallback): `?api_key=<router_key>`
|
||||
|
||||
Examples:
|
||||
|
||||
```bash
|
||||
curl -H "Authorization: Bearer $NOMYO_ROUTER_API_KEY" http://localhost:12434/api/tags
|
||||
curl "http://localhost:12434/api/tags?api_key=$NOMYO_ROUTER_API_KEY"
|
||||
|
|
|
|||
41
db.py
41
db.py
|
|
@ -63,6 +63,7 @@ class TokenDatabase:
|
|||
)
|
||||
''')
|
||||
await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_timestamp ON token_time_series(timestamp)')
|
||||
await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_model_ts ON token_time_series(model, timestamp)')
|
||||
await db.commit()
|
||||
|
||||
async def update_token_counts(self, endpoint: str, model: str, input_tokens: int, output_tokens: int):
|
||||
|
|
@ -178,6 +179,46 @@ class TokenDatabase:
|
|||
'timestamp': row[5]
|
||||
}
|
||||
|
||||
async def get_time_series_for_model(self, model: str, limit: int = 50000):
    """Yield time series entries for a specific model, newest first.

    The filtering (``WHERE model = ?``), ordering and limiting are done by
    the database, so only matching rows are ever transferred to Python.

    The rows are fetched in one go while ``self._operation_lock`` is held and
    are yielded only *after* the lock has been released.  Yielding from
    inside the lock would keep it held for as long as the consumer takes to
    iterate (or forever, if the consumer abandons the generator), and a
    consumer awaiting another lock-taking database method mid-iteration
    would deadlock.  NOTE(review): this assumes ``_operation_lock`` is a
    non-reentrant lock such as ``asyncio.Lock`` -- confirm in ``__init__``.

    Args:
        model: Model name to filter on.
        limit: Maximum number of rows to return (newest rows win).

    Yields:
        dict with the keys ``endpoint``, ``input_tokens``, ``output_tokens``,
        ``total_tokens`` and ``timestamp``.
    """
    db = await self._get_connection()
    async with self._operation_lock:
        async with db.execute('''
            SELECT endpoint, input_tokens, output_tokens, total_tokens, timestamp
            FROM token_time_series
            WHERE model = ?
            ORDER BY timestamp DESC
            LIMIT ?
        ''', (model, limit)) as cursor:
            # Materialise while locked; at most `limit` rows.
            rows = await cursor.fetchall()

    # Lock released: the consumer may now take as long as it likes.
    for row in rows:
        yield {
            'endpoint': row[0],
            'input_tokens': row[1],
            'output_tokens': row[2],
            'total_tokens': row[3],
            'timestamp': row[4],
        }
|
||||
|
||||
async def get_endpoint_distribution_for_model(self, model: str) -> dict:
    """Map each endpoint to the summed ``total_tokens`` recorded for *model*.

    The aggregation (``SUM`` grouped by endpoint) runs inside the database;
    Python only turns the resulting ``(endpoint, total)`` rows into a dict.
    An unknown model simply produces an empty dict.
    """
    connection = await self._get_connection()
    async with self._operation_lock:
        async with connection.execute('''
            SELECT endpoint, SUM(total_tokens)
            FROM token_time_series
            WHERE model = ?
            GROUP BY endpoint
        ''', (model,)) as cursor:
            grouped = await cursor.fetchall()

    distribution = {}
    for record in grouped:
        distribution[record[0]] = record[1]
    return distribution
|
||||
|
||||
async def get_token_counts_for_model(self, model):
|
||||
"""Get token counts for a specific model, aggregated across all endpoints."""
|
||||
db = await self._get_connection()
|
||||
|
|
|
|||
|
|
@ -36,5 +36,6 @@ tqdm==4.67.1
|
|||
typing-inspection==0.4.1
|
||||
typing_extensions==4.14.1
|
||||
uvicorn==0.38.0
|
||||
uvloop
|
||||
yarl==1.20.1
|
||||
aiosqlite
|
||||
|
|
|
|||
350
router.py
350
router.py
|
|
@ -2,11 +2,11 @@
|
|||
title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing
|
||||
author: alpha-nerd-nomyo
|
||||
author_url: https://github.com/nomyo-ai
|
||||
version: 0.6
|
||||
version: 0.7
|
||||
license: AGPL
|
||||
"""
|
||||
# -------------------------------------------------------------
|
||||
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets
|
||||
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math
|
||||
try:
|
||||
import truststore; truststore.inject_into_ssl()
|
||||
except ImportError:
|
||||
|
|
@ -75,7 +75,7 @@ def _mask_secrets(text: str) -> str:
|
|||
return text
|
||||
# OpenAI-style keys (sk-...) and generic "api key" mentions
|
||||
text = re.sub(r"sk-[A-Za-z0-9]{4}[A-Za-z0-9_-]*", "sk-***redacted***", text)
|
||||
text = re.sub(r"(?i)(api[-_ ]key\\s*[:=]\\s*)([^\\s]+)", r"\\1***redacted***", text)
|
||||
text = re.sub(r"(?i)(api[-_ ]key\s*[:=]\s*)([^\s]+)", r"\1***redacted***", text)
|
||||
return text
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -374,8 +374,11 @@ def _extract_llama_quant(name: str) -> str:
|
|||
|
||||
def _is_llama_model_loaded(item: dict) -> bool:
|
||||
"""Return True if a llama-server /v1/models item has status 'loaded'.
|
||||
Handles both dict format ({"value": "loaded"}) and plain string ("loaded")."""
|
||||
Handles both dict format ({"value": "loaded"}) and plain string ("loaded").
|
||||
If no status field is present, the model is always-loaded (not dynamically managed)."""
|
||||
status = item.get("status")
|
||||
if status is None:
|
||||
return True # No status field: model is always loaded (e.g. single-model servers)
|
||||
if isinstance(status, dict):
|
||||
return status.get("value") == "loaded"
|
||||
if isinstance(status, str):
|
||||
|
|
@ -925,11 +928,12 @@ async def decrement_usage(endpoint: str, model: str) -> None:
|
|||
# usage_counts.pop(endpoint, None)
|
||||
await publish_snapshot()
|
||||
|
||||
async def _make_chat_request(endpoint: str, model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
|
||||
async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
|
||||
"""
|
||||
Helper function to make a chat request to a specific endpoint.
|
||||
Handles endpoint selection, client creation, usage tracking, and request execution.
|
||||
"""
|
||||
endpoint, tracking_model = await choose_endpoint(model) # selects and atomically reserves
|
||||
use_openai = is_openai_compatible(endpoint)
|
||||
if use_openai:
|
||||
if ":latest" in model:
|
||||
|
|
@ -959,10 +963,6 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
|
|||
else:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
|
||||
# Normalize model name for tracking so it matches the PS table key
|
||||
tracking_model = get_tracking_model(endpoint, model)
|
||||
await increment_usage(endpoint, tracking_model)
|
||||
|
||||
try:
|
||||
if use_openai:
|
||||
start_ts = time.perf_counter()
|
||||
|
|
@ -1054,18 +1054,11 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
|
|||
|
||||
moe_reqs = []
|
||||
|
||||
# Generate 3 responses
|
||||
response1_endpoint = await choose_endpoint(model)
|
||||
response1_task = asyncio.create_task(_make_chat_request(response1_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
await asyncio.sleep(0.01) # Small delay to allow usage count to update
|
||||
|
||||
response2_endpoint = await choose_endpoint(model)
|
||||
response2_task = asyncio.create_task(_make_chat_request(response2_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
await asyncio.sleep(0.01) # Small delay to allow usage count to update
|
||||
|
||||
response3_endpoint = await choose_endpoint(model)
|
||||
response3_task = asyncio.create_task(_make_chat_request(response3_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
await asyncio.sleep(0.01) # Small delay to allow usage count to update
|
||||
# Generate 3 responses — choose_endpoint is called inside _make_chat_request and
|
||||
# atomically reserves a slot, so all 3 tasks see each other's load immediately.
|
||||
response1_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
response2_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
response3_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
|
||||
responses = await asyncio.gather(response1_task, response2_task, response3_task)
|
||||
|
||||
|
|
@ -1074,17 +1067,9 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
|
|||
moe_reqs.append(moe_req)
|
||||
|
||||
# Generate 3 critiques
|
||||
critique1_endpoint = await choose_endpoint(model)
|
||||
critique1_task = asyncio.create_task(_make_chat_request(critique1_endpoint, model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
await asyncio.sleep(0.01) # Small delay to allow usage count to update
|
||||
|
||||
critique2_endpoint = await choose_endpoint(model)
|
||||
critique2_task = asyncio.create_task(_make_chat_request(critique2_endpoint, model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
await asyncio.sleep(0.01) # Small delay to allow usage count to update
|
||||
|
||||
critique3_endpoint = await choose_endpoint(model)
|
||||
critique3_task = asyncio.create_task(_make_chat_request(critique3_endpoint, model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
await asyncio.sleep(0.01) # Small delay to allow usage count to update
|
||||
critique1_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
critique2_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
critique3_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
|
||||
critiques = await asyncio.gather(critique1_task, critique2_task, critique3_task)
|
||||
|
||||
|
|
@ -1092,8 +1077,7 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
|
|||
m = enhance.moe_select_candidate(query, critiques)
|
||||
|
||||
# Generate final response
|
||||
final_endpoint = await choose_endpoint(model)
|
||||
return await _make_chat_request(final_endpoint, model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
|
||||
return await _make_chat_request(model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
|
||||
|
||||
def iso8601_ns():
|
||||
ns = time.time_ns()
|
||||
|
|
@ -1462,7 +1446,7 @@ async def get_usage_counts() -> Dict:
|
|||
# -------------------------------------------------------------
|
||||
# 5. Endpoint selection logic (respecting the configurable limit)
|
||||
# -------------------------------------------------------------
|
||||
async def choose_endpoint(model: str) -> str:
|
||||
async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]:
|
||||
"""
|
||||
Determine which endpoint to use for the given model while respecting
|
||||
the `max_concurrent_connections` per endpoint‑model pair **and**
|
||||
|
|
@ -1523,7 +1507,8 @@ async def choose_endpoint(model: str) -> str:
|
|||
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
|
||||
loaded_sets = await asyncio.gather(*load_tasks)
|
||||
|
||||
# Protect all reads of usage_counts with the lock
|
||||
# Protect all reads/writes of usage_counts with the lock so that selection
|
||||
# and reservation are atomic — concurrent callers see each other's pending load.
|
||||
async with usage_lock:
|
||||
# Helper: current usage for (endpoint, model) using the same normalized key
|
||||
# that increment_usage/decrement_usage store — raw model names differ from
|
||||
|
|
@ -1541,34 +1526,37 @@ async def choose_endpoint(model: str) -> str:
|
|||
# Sort ascending for load balancing — all endpoints here already have the
|
||||
# model loaded, so there is no model-switching cost to optimise for.
|
||||
loaded_and_free.sort(key=tracking_usage)
|
||||
|
||||
# When all candidates are equally idle, randomise to avoid always picking
|
||||
# the first entry in a stable sort.
|
||||
if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
|
||||
return random.choice(loaded_and_free)
|
||||
selected = random.choice(loaded_and_free)
|
||||
else:
|
||||
selected = loaded_and_free[0]
|
||||
else:
|
||||
# 4️⃣ Endpoints among the candidates that simply have a free slot
|
||||
endpoints_with_free_slot = [
|
||||
ep for ep in candidate_endpoints
|
||||
if tracking_usage(ep) < config.max_concurrent_connections
|
||||
]
|
||||
|
||||
return loaded_and_free[0]
|
||||
if endpoints_with_free_slot:
|
||||
# Sort by total endpoint load (ascending) to prefer idle endpoints.
|
||||
endpoints_with_free_slot.sort(
|
||||
key=lambda ep: sum(usage_counts.get(ep, {}).values())
|
||||
)
|
||||
if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
|
||||
selected = random.choice(endpoints_with_free_slot)
|
||||
else:
|
||||
selected = endpoints_with_free_slot[0]
|
||||
else:
|
||||
# 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue)
|
||||
selected = min(candidate_endpoints, key=tracking_usage)
|
||||
|
||||
# 4️⃣ Endpoints among the candidates that simply have a free slot
|
||||
endpoints_with_free_slot = [
|
||||
ep for ep in candidate_endpoints
|
||||
if tracking_usage(ep) < config.max_concurrent_connections
|
||||
]
|
||||
|
||||
if endpoints_with_free_slot:
|
||||
# Sort by total endpoint load (ascending) to prefer idle endpoints.
|
||||
endpoints_with_free_slot.sort(
|
||||
key=lambda ep: sum(usage_counts.get(ep, {}).values())
|
||||
)
|
||||
|
||||
if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
|
||||
return random.choice(endpoints_with_free_slot)
|
||||
|
||||
return endpoints_with_free_slot[0]
|
||||
|
||||
# 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue)
|
||||
ep = min(candidate_endpoints, key=tracking_usage)
|
||||
return ep
|
||||
tracking_model = get_tracking_model(selected, model)
|
||||
if reserve:
|
||||
usage_counts[selected][tracking_model] += 1
|
||||
await publish_snapshot()
|
||||
return selected, tracking_model
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 6. API route – Generate
|
||||
|
|
@ -1609,10 +1597,8 @@ async def proxy(request: Request):
|
|||
raise HTTPException(status_code=400, detail=error_msg) from e
|
||||
|
||||
|
||||
endpoint = await choose_endpoint(model)
|
||||
endpoint, tracking_model = await choose_endpoint(model)
|
||||
use_openai = is_openai_compatible(endpoint)
|
||||
# Normalize model name for tracking so it matches the PS table key
|
||||
tracking_model = get_tracking_model(endpoint, model)
|
||||
if use_openai:
|
||||
if ":latest" in model:
|
||||
model = model.split(":latest")
|
||||
|
|
@ -1637,7 +1623,6 @@ async def proxy(request: Request):
|
|||
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||
else:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
await increment_usage(endpoint, tracking_model)
|
||||
|
||||
# 4. Async generator that streams data and decrements the counter
|
||||
async def stream_generate_response():
|
||||
|
|
@ -1732,10 +1717,8 @@ async def chat_proxy(request: Request):
|
|||
opt = True
|
||||
else:
|
||||
opt = False
|
||||
endpoint = await choose_endpoint(model)
|
||||
endpoint, tracking_model = await choose_endpoint(model)
|
||||
use_openai = is_openai_compatible(endpoint)
|
||||
# Normalize model name for tracking so it matches the PS table key
|
||||
tracking_model = get_tracking_model(endpoint, model)
|
||||
if use_openai:
|
||||
if ":latest" in model:
|
||||
model = model.split(":latest")
|
||||
|
|
@ -1766,7 +1749,6 @@ async def chat_proxy(request: Request):
|
|||
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||
else:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
await increment_usage(endpoint, tracking_model)
|
||||
# 3. Async generator that streams chat data and decrements the counter
|
||||
async def stream_chat_response():
|
||||
try:
|
||||
|
|
@ -1858,10 +1840,8 @@ async def embedding_proxy(request: Request):
|
|||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
endpoint, tracking_model = await choose_endpoint(model)
|
||||
use_openai = is_openai_compatible(endpoint)
|
||||
# Normalize model name for tracking so it matches the PS table key
|
||||
tracking_model = get_tracking_model(endpoint, model)
|
||||
if use_openai:
|
||||
if ":latest" in model:
|
||||
model = model.split(":latest")
|
||||
|
|
@ -1869,7 +1849,6 @@ async def embedding_proxy(request: Request):
|
|||
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
|
||||
else:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
await increment_usage(endpoint, tracking_model)
|
||||
# 3. Async generator that streams embedding data and decrements the counter
|
||||
async def stream_embedding_response():
|
||||
try:
|
||||
|
|
@ -1926,10 +1905,8 @@ async def embed_proxy(request: Request):
|
|||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
endpoint, tracking_model = await choose_endpoint(model)
|
||||
use_openai = is_openai_compatible(endpoint)
|
||||
# Normalize model name for tracking so it matches the PS table key
|
||||
tracking_model = get_tracking_model(endpoint, model)
|
||||
if use_openai:
|
||||
if ":latest" in model:
|
||||
model = model.split(":latest")
|
||||
|
|
@ -1937,7 +1914,6 @@ async def embed_proxy(request: Request):
|
|||
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
|
||||
else:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
await increment_usage(endpoint, tracking_model)
|
||||
# 3. Async generator that streams embed data and decrements the counter
|
||||
async def stream_embedding_response():
|
||||
try:
|
||||
|
|
@ -2035,8 +2011,7 @@ async def show_proxy(request: Request, model: Optional[str] = None):
|
|||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
#await increment_usage(endpoint, model)
|
||||
endpoint, _ = await choose_endpoint(model, reserve=False)
|
||||
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
|
||||
|
|
@ -2111,22 +2086,10 @@ async def stats_proxy(request: Request, model: Optional[str] = None):
|
|||
status_code=404, detail="No token data found for this model"
|
||||
)
|
||||
|
||||
# Get time series data for the last 30 days (43200 minutes = 30 days)
|
||||
# Assuming entries are grouped by minute, 30 days = 43200 entries max
|
||||
time_series = []
|
||||
endpoint_totals = defaultdict(int) # Track tokens per endpoint
|
||||
|
||||
async for entry in db.get_latest_time_series(limit=50000):
|
||||
if entry['model'] == model:
|
||||
time_series.append({
|
||||
'endpoint': entry['endpoint'],
|
||||
'timestamp': entry['timestamp'],
|
||||
'input_tokens': entry['input_tokens'],
|
||||
'output_tokens': entry['output_tokens'],
|
||||
'total_tokens': entry['total_tokens']
|
||||
})
|
||||
# Accumulate total tokens per endpoint
|
||||
endpoint_totals[entry['endpoint']] += entry['total_tokens']
|
||||
time_series = [
|
||||
entry async for entry in db.get_time_series_for_model(model)
|
||||
]
|
||||
endpoint_distribution = await db.get_endpoint_distribution_for_model(model)
|
||||
|
||||
return {
|
||||
'model': model,
|
||||
|
|
@ -2134,7 +2097,7 @@ async def stats_proxy(request: Request, model: Optional[str] = None):
|
|||
'output_tokens': token_data['output_tokens'],
|
||||
'total_tokens': token_data['total_tokens'],
|
||||
'time_series': time_series,
|
||||
'endpoint_distribution': dict(endpoint_totals)
|
||||
'endpoint_distribution': endpoint_distribution,
|
||||
}
|
||||
|
||||
# -------------------------------------------------------------
|
||||
|
|
@ -2418,8 +2381,10 @@ async def ps_proxy(request: Request):
|
|||
})
|
||||
|
||||
# 3. Return a JSONResponse with deduplicated currently deployed models
|
||||
# Deduplicate on 'name' rather than 'digest': llama-server models always
|
||||
# have digest="" so deduping on digest collapses all of them to one entry.
|
||||
return JSONResponse(
|
||||
content={"models": dedupe_on_keys(models['models'], ['digest'])},
|
||||
content={"models": dedupe_on_keys(models['models'], ['name'])},
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
|
|
@ -2565,7 +2530,7 @@ async def config_proxy(request: Request):
|
|||
client: aiohttp.ClientSession = app_state["session"]
|
||||
headers = None
|
||||
if "/v1" in url:
|
||||
headers = {"Authorization": "Bearer " + config.api_keys[url]}
|
||||
headers = {"Authorization": "Bearer " + config.api_keys.get(url, "no-key")}
|
||||
target_url = f"{url}/models"
|
||||
else:
|
||||
target_url = f"{url}/api/version"
|
||||
|
|
@ -2625,10 +2590,7 @@ async def openai_embedding_proxy(request: Request):
|
|||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
# Normalize model name for tracking so it matches the PS table key
|
||||
tracking_model = get_tracking_model(endpoint, model)
|
||||
await increment_usage(endpoint, tracking_model)
|
||||
endpoint, tracking_model = await choose_endpoint(model)
|
||||
if is_openai_compatible(endpoint):
|
||||
api_key = config.api_keys.get(endpoint, "no-key")
|
||||
else:
|
||||
|
|
@ -2637,13 +2599,16 @@ async def openai_embedding_proxy(request: Request):
|
|||
|
||||
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=api_key)
|
||||
|
||||
# 3. Async generator that streams embedding data and decrements the counter
|
||||
async_gen = await oclient.embeddings.create(input=doc, model=model)
|
||||
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
# 5. Return a StreamingResponse backed by the generator
|
||||
return async_gen
|
||||
try:
|
||||
async_gen = await oclient.embeddings.create(input=doc, model=model)
|
||||
result = async_gen.model_dump()
|
||||
for item in result.get("data", []):
|
||||
emb = item.get("embedding")
|
||||
if emb:
|
||||
item["embedding"] = [0.0 if isinstance(v, float) and not math.isfinite(v) else v for v in emb]
|
||||
return JSONResponse(content=result)
|
||||
finally:
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 22. API route – OpenAI compatible Chat Completions
|
||||
|
|
@ -2676,12 +2641,21 @@ async def openai_chat_completions_proxy(request: Request):
|
|||
logprobs = payload.get("logprobs")
|
||||
top_logprobs = payload.get("top_logprobs")
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
if not isinstance(messages, list):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'messages' (must be a list)"
|
||||
)
|
||||
|
||||
if ":latest" in model:
|
||||
model = model.split(":latest")
|
||||
model = model[0]
|
||||
|
||||
params = {
|
||||
"messages": messages,
|
||||
"messages": messages,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
|
|
@ -2703,23 +2677,11 @@ async def openai_chat_completions_proxy(request: Request):
|
|||
}
|
||||
|
||||
params.update({k: v for k, v in optional_params.items() if v is not None})
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
if not isinstance(messages, list):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'messages' (must be a list)"
|
||||
)
|
||||
except orjson.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
# Normalize model name for tracking so it matches the PS table key
|
||||
tracking_model = get_tracking_model(endpoint, model)
|
||||
await increment_usage(endpoint, tracking_model)
|
||||
endpoint, tracking_model = await choose_endpoint(model)
|
||||
base_url = ep2base(endpoint)
|
||||
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||
# 3. Async generator that streams completions data and decrements the counter
|
||||
|
|
@ -2825,12 +2787,21 @@ async def openai_completions_proxy(request: Request):
|
|||
max_completion_tokens = payload.get("max_completion_tokens")
|
||||
suffix = payload.get("suffix")
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
if not prompt:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'prompt'"
|
||||
)
|
||||
|
||||
if ":latest" in model:
|
||||
model = model.split(":latest")
|
||||
model = model[0]
|
||||
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"prompt": prompt,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
|
|
@ -2849,23 +2820,11 @@ async def openai_completions_proxy(request: Request):
|
|||
}
|
||||
|
||||
params.update({k: v for k, v in optional_params.items() if v is not None})
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
if not prompt:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'prompt'"
|
||||
)
|
||||
except orjson.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
# Normalize model name for tracking so it matches the PS table key
|
||||
tracking_model = get_tracking_model(endpoint, model)
|
||||
await increment_usage(endpoint, tracking_model)
|
||||
endpoint, tracking_model = await choose_endpoint(model)
|
||||
base_url = ep2base(endpoint)
|
||||
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||
|
||||
|
|
@ -3001,7 +2960,124 @@ async def openai_models_proxy(request: Request):
|
|||
)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 25. Serve the static front‑end
|
||||
# 25. API route – OpenAI/Jina/Cohere compatible Rerank
|
||||
# -------------------------------------------------------------
|
||||
@app.post("/v1/rerank")
@app.post("/rerank")
async def rerank_proxy(request: Request):
    """
    Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint.

    Compatible with the Jina/Cohere rerank API convention used by llama-server,
    vLLM, and services such as Cohere and Jina AI.

    Ollama does not natively support reranking; requests routed to a plain Ollama
    endpoint will receive a 501 Not Implemented response.

    Request body:
        model (str, required) – reranker model name
        query (str, required) – search query
        documents (list[str], required) – candidate documents to rank
        top_n (int, optional) – limit returned results (default: all)
        return_documents (bool, optional) – include document text in results
        max_tokens_per_doc (int, optional) – truncation limit per document

    Response (Jina/Cohere-compatible):
        {
            "id": "...",
            "model": "...",
            "usage": {"prompt_tokens": N, "total_tokens": N},
            "results": [{"index": 0, "relevance_score": 0.95}, ...]
        }

    Raises:
        HTTPException 400 – invalid JSON, non-object body or missing field
        HTTPException 404 – choose_endpoint raised RuntimeError for the model
        HTTPException 501 – the selected endpoint is a plain Ollama instance
        HTTPException 502 – upstream answered successfully but not with JSON
        HTTPException <upstream status> – upstream answered with status >= 400
    """
    # 1. Parse and validate the request body
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # Valid JSON that is not an object (e.g. a list) has no .get(); reject it
    # explicitly instead of failing with an AttributeError / 500.
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    query = payload.get("query")
    documents = payload.get("documents")

    if not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not query:
        raise HTTPException(status_code=400, detail="Missing required field 'query'")
    if not isinstance(documents, list) or not documents:
        raise HTTPException(status_code=400, detail="Missing or empty required field 'documents' (must be a non-empty list)")

    # 2. Determine which endpoint serves this model.  The call is made with the
    #    default reserve flag, so a usage slot is held from here on and must be
    #    released on every exit path below.
    try:
        endpoint, tracking_model = await choose_endpoint(model)
    except RuntimeError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    # Everything after the reservation lives in one try/finally so the usage
    # counter is decremented exactly once, whatever happens.
    try:
        # Ollama endpoints have no native rerank support
        if not is_openai_compatible(endpoint):
            raise HTTPException(
                status_code=501,
                detail=(
                    f"Endpoint '{endpoint}' is a plain Ollama instance which does not support "
                    "reranking. Use a llama-server or OpenAI-compatible endpoint with a "
                    "dedicated reranker model."
                ),
            )

        if ":latest" in model:
            model = model.split(":latest")[0]

        # Build upstream rerank request body – forward only recognised fields
        upstream_payload: dict = {"model": model, "query": query, "documents": documents}
        for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
            if optional_key in payload:
                upstream_payload[optional_key] = payload[optional_key]

        # Determine upstream URL:
        #   llama-server exposes /v1/rerank (base already contains /v1 for llama_server_endpoints)
        #   External OpenAI endpoints expose /rerank under their /v1 base
        if endpoint in config.llama_server_endpoints:
            # llama-server: endpoint may or may not already contain /v1
            if "/v1" in endpoint:
                rerank_url = f"{endpoint}/rerank"
            else:
                rerank_url = f"{endpoint}/v1/rerank"
        else:
            # External OpenAI-compatible: ep2base gives us the /v1 base
            rerank_url = f"{ep2base(endpoint)}/rerank"

        api_key = config.api_keys.get(endpoint, "no-key")
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }

        client: aiohttp.ClientSession = app_state["session"]
        async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp:
            response_bytes = await resp.read()
            if resp.status >= 400:
                raise HTTPException(
                    status_code=resp.status,
                    detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")),
                )

        try:
            data = orjson.loads(response_bytes)
        except orjson.JSONDecodeError as e:
            # Upstream claimed success but did not send JSON – that is a
            # gateway problem, not an internal server error.
            raise HTTPException(status_code=502, detail=f"Upstream returned invalid JSON: {e}") from e

        # Record token usage if the upstream returned a usage object
        usage = (data.get("usage") if isinstance(data, dict) else None) or {}
        prompt_tok = usage.get("prompt_tokens") or 0
        total_tok = usage.get("total_tokens") or 0
        # For reranking there are no completion tokens; we record prompt tokens only
        if prompt_tok or total_tok:
            await token_queue.put((endpoint, tracking_model, prompt_tok, 0))

        return JSONResponse(content=data)
    finally:
        await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 26. Serve the static front‑end
|
||||
# -------------------------------------------------------------
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
|
||||
|
|
@ -3095,7 +3171,7 @@ async def usage_stream(request: Request):
|
|||
# -------------------------------------------------------------
|
||||
@app.on_event("startup")
|
||||
async def startup_event() -> None:
|
||||
global config, db
|
||||
global config, db, token_worker_task, flush_task
|
||||
# Load YAML config (or use defaults if not present)
|
||||
config_path = _config_path_from_env()
|
||||
config = Config.from_yaml(config_path)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue