diff --git a/README.md b/README.md index d7e08c1..1780c40 100644 --- a/README.md +++ b/README.md @@ -14,12 +14,17 @@ Copy/Clone the repository, edit the config.yaml by adding your Ollama backend se ``` # config.yaml +# Ollama or OpenAI API V1 endpoints endpoints: - http://ollama0:11434 - http://ollama1:11434 - http://ollama2:11434 - https://api.openai.com/v1 +# llama.cpp server endpoints +llama_server_endpoints: + - http://192.168.0.33:8889/v1 + # Maximum concurrent connections *per endpoint‑model pair* max_concurrent_connections: 2 @@ -34,6 +39,7 @@ api_keys: "http://192.168.0.51:11434": "ollama" "http://192.168.0.52:11434": "ollama" "https://api.openai.com/v1": "${OPENAI_KEY}" + "http://192.168.0.33:8889/v1": "llama" ``` Run the NOMYO Router in a dedicated virtual environment, install the requirements and run with uvicorn: @@ -58,6 +64,12 @@ finally you can uvicorn router:app --host 127.0.0.1 --port 12434 ``` +In highly concurrent scenarios (> 500 simultaneous requests), you can also run the router with uvloop: + +``` +uvicorn router:app --host 127.0.0.1 --port 12434 --loop uvloop +``` + ## Docker Deployment Build the container image locally: @@ -98,7 +110,6 @@ This way the Ollama backend servers are utilized more efficient than by simply u NOMYO Router also supports OpenAI API compatible v1 backend servers. 
- ## Supplying the router API key If you set `nomyo-router-api-key` in `config.yaml` (or `NOMYO_ROUTER_API_KEY` env), every request to NOMYO Router must include the key: @@ -107,6 +118,7 @@ If you set `nomyo-router-api-key` in `config.yaml` (or `NOMYO_ROUTER_API_KEY` en - Query param (fallback): `?api_key=` Examples: + ```bash curl -H "Authorization: Bearer $NOMYO_ROUTER_API_KEY" http://localhost:12434/api/tags curl "http://localhost:12434/api/tags?api_key=$NOMYO_ROUTER_API_KEY" diff --git a/db.py b/db.py index 11df49c..af7b252 100644 --- a/db.py +++ b/db.py @@ -63,6 +63,7 @@ class TokenDatabase: ) ''') await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_timestamp ON token_time_series(timestamp)') + await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_model_ts ON token_time_series(model, timestamp)') await db.commit() async def update_token_counts(self, endpoint: str, model: str, input_tokens: int, output_tokens: int): @@ -178,6 +179,46 @@ class TokenDatabase: 'timestamp': row[5] } + async def get_time_series_for_model(self, model: str, limit: int = 50000): + """Get time series entries for a specific model, newest first. + + Uses the (model, timestamp) composite index so the DB does the filtering + instead of returning all rows and discarding non-matching ones in Python. + """ + db = await self._get_connection() + async with self._operation_lock: + async with db.execute(''' + SELECT endpoint, input_tokens, output_tokens, total_tokens, timestamp + FROM token_time_series + WHERE model = ? + ORDER BY timestamp DESC + LIMIT ? + ''', (model, limit)) as cursor: + async for row in cursor: + yield { + 'endpoint': row[0], + 'input_tokens': row[1], + 'output_tokens': row[2], + 'total_tokens': row[3], + 'timestamp': row[4], + } + + async def get_endpoint_distribution_for_model(self, model: str) -> dict: + """Return total tokens per endpoint for a specific model as a plain dict. + + Computed entirely in SQL so no Python-side aggregation is needed. 
+ """ + db = await self._get_connection() + async with self._operation_lock: + async with db.execute(''' + SELECT endpoint, SUM(total_tokens) + FROM token_time_series + WHERE model = ? + GROUP BY endpoint + ''', (model,)) as cursor: + rows = await cursor.fetchall() + return {row[0]: row[1] for row in rows} + async def get_token_counts_for_model(self, model): """Get token counts for a specific model, aggregated across all endpoints.""" db = await self._get_connection() diff --git a/requirements.txt b/requirements.txt index aa51a0f..f1db419 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,5 +36,6 @@ tqdm==4.67.1 typing-inspection==0.4.1 typing_extensions==4.14.1 uvicorn==0.38.0 +uvloop yarl==1.20.1 aiosqlite diff --git a/router.py b/router.py index 1ef096a..691d11d 100644 --- a/router.py +++ b/router.py @@ -2,11 +2,11 @@ title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing author: alpha-nerd-nomyo author_url: https://github.com/nomyo-ai -version: 0.6 +version: 0.7 license: AGPL """ # ------------------------------------------------------------- -import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets +import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math try: import truststore; truststore.inject_into_ssl() except ImportError: @@ -75,7 +75,7 @@ def _mask_secrets(text: str) -> str: return text # OpenAI-style keys (sk-...) 
and generic "api key" mentions text = re.sub(r"sk-[A-Za-z0-9]{4}[A-Za-z0-9_-]*", "sk-***redacted***", text) - text = re.sub(r"(?i)(api[-_ ]key\\s*[:=]\\s*)([^\\s]+)", r"\\1***redacted***", text) + text = re.sub(r"(?i)(api[-_ ]key\s*[:=]\s*)([^\s]+)", r"\1***redacted***", text) return text # ------------------------------------------------------------------ @@ -374,8 +374,11 @@ def _extract_llama_quant(name: str) -> str: def _is_llama_model_loaded(item: dict) -> bool: """Return True if a llama-server /v1/models item has status 'loaded'. - Handles both dict format ({"value": "loaded"}) and plain string ("loaded").""" + Handles both dict format ({"value": "loaded"}) and plain string ("loaded"). + If no status field is present, the model is always-loaded (not dynamically managed).""" status = item.get("status") + if status is None: + return True # No status field: model is always loaded (e.g. single-model servers) if isinstance(status, dict): return status.get("value") == "loaded" if isinstance(status, str): @@ -925,11 +928,12 @@ async def decrement_usage(endpoint: str, model: str) -> None: # usage_counts.pop(endpoint, None) await publish_snapshot() -async def _make_chat_request(endpoint: str, model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse: +async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse: """ Helper function to make a chat request to a specific endpoint. Handles endpoint selection, client creation, usage tracking, and request execution. 
""" + endpoint, tracking_model = await choose_endpoint(model) # selects and atomically reserves use_openai = is_openai_compatible(endpoint) if use_openai: if ":latest" in model: @@ -959,10 +963,6 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No else: client = ollama.AsyncClient(host=endpoint) - # Normalize model name for tracking so it matches the PS table key - tracking_model = get_tracking_model(endpoint, model) - await increment_usage(endpoint, tracking_model) - try: if use_openai: start_ts = time.perf_counter() @@ -1054,18 +1054,11 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool moe_reqs = [] - # Generate 3 responses - response1_endpoint = await choose_endpoint(model) - response1_task = asyncio.create_task(_make_chat_request(response1_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) - await asyncio.sleep(0.01) # Small delay to allow usage count to update - - response2_endpoint = await choose_endpoint(model) - response2_task = asyncio.create_task(_make_chat_request(response2_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) - await asyncio.sleep(0.01) # Small delay to allow usage count to update - - response3_endpoint = await choose_endpoint(model) - response3_task = asyncio.create_task(_make_chat_request(response3_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) - await asyncio.sleep(0.01) # Small delay to allow usage count to update + # Generate 3 responses — choose_endpoint is called inside _make_chat_request and + # atomically reserves a slot, so all 3 tasks see each other's load immediately. 
+ response1_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) + response2_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) + response3_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) responses = await asyncio.gather(response1_task, response2_task, response3_task) @@ -1074,17 +1067,9 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool moe_reqs.append(moe_req) # Generate 3 critiques - critique1_endpoint = await choose_endpoint(model) - critique1_task = asyncio.create_task(_make_chat_request(critique1_endpoint, model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) - await asyncio.sleep(0.01) # Small delay to allow usage count to update - - critique2_endpoint = await choose_endpoint(model) - critique2_task = asyncio.create_task(_make_chat_request(critique2_endpoint, model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) - await asyncio.sleep(0.01) # Small delay to allow usage count to update - - critique3_endpoint = await choose_endpoint(model) - critique3_task = asyncio.create_task(_make_chat_request(critique3_endpoint, model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) - await asyncio.sleep(0.01) # Small delay to allow usage count to update + critique1_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) + critique2_task = 
asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) + critique3_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) critiques = await asyncio.gather(critique1_task, critique2_task, critique3_task) @@ -1092,8 +1077,7 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool m = enhance.moe_select_candidate(query, critiques) # Generate final response - final_endpoint = await choose_endpoint(model) - return await _make_chat_request(final_endpoint, model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive) + return await _make_chat_request(model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive) def iso8601_ns(): ns = time.time_ns() @@ -1462,7 +1446,7 @@ async def get_usage_counts() -> Dict: # ------------------------------------------------------------- # 5. Endpoint selection logic (respecting the configurable limit) # ------------------------------------------------------------- -async def choose_endpoint(model: str) -> str: +async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]: """ Determine which endpoint to use for the given model while respecting the `max_concurrent_connections` per endpoint‑model pair **and** @@ -1523,7 +1507,8 @@ async def choose_endpoint(model: str) -> str: load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints] loaded_sets = await asyncio.gather(*load_tasks) - # Protect all reads of usage_counts with the lock + # Protect all reads/writes of usage_counts with the lock so that selection + # and reservation are atomic — concurrent callers see each other's pending load. 
async with usage_lock: # Helper: current usage for (endpoint, model) using the same normalized key # that increment_usage/decrement_usage store — raw model names differ from @@ -1541,34 +1526,37 @@ async def choose_endpoint(model: str) -> str: # Sort ascending for load balancing — all endpoints here already have the # model loaded, so there is no model-switching cost to optimise for. loaded_and_free.sort(key=tracking_usage) - # When all candidates are equally idle, randomise to avoid always picking # the first entry in a stable sort. if all(tracking_usage(ep) == 0 for ep in loaded_and_free): - return random.choice(loaded_and_free) + selected = random.choice(loaded_and_free) + else: + selected = loaded_and_free[0] + else: + # 4️⃣ Endpoints among the candidates that simply have a free slot + endpoints_with_free_slot = [ + ep for ep in candidate_endpoints + if tracking_usage(ep) < config.max_concurrent_connections + ] - return loaded_and_free[0] + if endpoints_with_free_slot: + # Sort by total endpoint load (ascending) to prefer idle endpoints. + endpoints_with_free_slot.sort( + key=lambda ep: sum(usage_counts.get(ep, {}).values()) + ) + if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot): + selected = random.choice(endpoints_with_free_slot) + else: + selected = endpoints_with_free_slot[0] + else: + # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue) + selected = min(candidate_endpoints, key=tracking_usage) - # 4️⃣ Endpoints among the candidates that simply have a free slot - endpoints_with_free_slot = [ - ep for ep in candidate_endpoints - if tracking_usage(ep) < config.max_concurrent_connections - ] - - if endpoints_with_free_slot: - # Sort by total endpoint load (ascending) to prefer idle endpoints. 
- endpoints_with_free_slot.sort( - key=lambda ep: sum(usage_counts.get(ep, {}).values()) - ) - - if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot): - return random.choice(endpoints_with_free_slot) - - return endpoints_with_free_slot[0] - - # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue) - ep = min(candidate_endpoints, key=tracking_usage) - return ep + tracking_model = get_tracking_model(selected, model) + if reserve: + usage_counts[selected][tracking_model] += 1 + await publish_snapshot() + return selected, tracking_model # ------------------------------------------------------------- # 6. API route – Generate @@ -1609,10 +1597,8 @@ async def proxy(request: Request): raise HTTPException(status_code=400, detail=error_msg) from e - endpoint = await choose_endpoint(model) + endpoint, tracking_model = await choose_endpoint(model) use_openai = is_openai_compatible(endpoint) - # Normalize model name for tracking so it matches the PS table key - tracking_model = get_tracking_model(endpoint, model) if use_openai: if ":latest" in model: model = model.split(":latest") @@ -1637,7 +1623,6 @@ async def proxy(request: Request): oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) - await increment_usage(endpoint, tracking_model) # 4. 
Async generator that streams data and decrements the counter async def stream_generate_response(): @@ -1732,10 +1717,8 @@ async def chat_proxy(request: Request): opt = True else: opt = False - endpoint = await choose_endpoint(model) + endpoint, tracking_model = await choose_endpoint(model) use_openai = is_openai_compatible(endpoint) - # Normalize model name for tracking so it matches the PS table key - tracking_model = get_tracking_model(endpoint, model) if use_openai: if ":latest" in model: model = model.split(":latest") @@ -1766,7 +1749,6 @@ async def chat_proxy(request: Request): oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) - await increment_usage(endpoint, tracking_model) # 3. Async generator that streams chat data and decrements the counter async def stream_chat_response(): try: @@ -1858,10 +1840,8 @@ async def embedding_proxy(request: Request): raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Endpoint logic - endpoint = await choose_endpoint(model) + endpoint, tracking_model = await choose_endpoint(model) use_openai = is_openai_compatible(endpoint) - # Normalize model name for tracking so it matches the PS table key - tracking_model = get_tracking_model(endpoint, model) if use_openai: if ":latest" in model: model = model.split(":latest") @@ -1869,7 +1849,6 @@ async def embedding_proxy(request: Request): client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) - await increment_usage(endpoint, tracking_model) # 3. Async generator that streams embedding data and decrements the counter async def stream_embedding_response(): try: @@ -1926,10 +1905,8 @@ async def embed_proxy(request: Request): raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. 
Endpoint logic - endpoint = await choose_endpoint(model) + endpoint, tracking_model = await choose_endpoint(model) use_openai = is_openai_compatible(endpoint) - # Normalize model name for tracking so it matches the PS table key - tracking_model = get_tracking_model(endpoint, model) if use_openai: if ":latest" in model: model = model.split(":latest") @@ -1937,7 +1914,6 @@ async def embed_proxy(request: Request): client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) - await increment_usage(endpoint, tracking_model) # 3. Async generator that streams embed data and decrements the counter async def stream_embedding_response(): try: @@ -2035,8 +2011,7 @@ async def show_proxy(request: Request, model: Optional[str] = None): raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Endpoint logic - endpoint = await choose_endpoint(model) - #await increment_usage(endpoint, model) + endpoint, _ = await choose_endpoint(model, reserve=False) client = ollama.AsyncClient(host=endpoint) @@ -2111,22 +2086,10 @@ async def stats_proxy(request: Request, model: Optional[str] = None): status_code=404, detail="No token data found for this model" ) - # Get time series data for the last 30 days (43200 minutes = 30 days) - # Assuming entries are grouped by minute, 30 days = 43200 entries max - time_series = [] - endpoint_totals = defaultdict(int) # Track tokens per endpoint - - async for entry in db.get_latest_time_series(limit=50000): - if entry['model'] == model: - time_series.append({ - 'endpoint': entry['endpoint'], - 'timestamp': entry['timestamp'], - 'input_tokens': entry['input_tokens'], - 'output_tokens': entry['output_tokens'], - 'total_tokens': entry['total_tokens'] - }) - # Accumulate total tokens per endpoint - endpoint_totals[entry['endpoint']] += entry['total_tokens'] + time_series = [ + entry async for entry in db.get_time_series_for_model(model) + ] + 
endpoint_distribution = await db.get_endpoint_distribution_for_model(model) return { 'model': model, @@ -2134,7 +2097,7 @@ async def stats_proxy(request: Request, model: Optional[str] = None): 'output_tokens': token_data['output_tokens'], 'total_tokens': token_data['total_tokens'], 'time_series': time_series, - 'endpoint_distribution': dict(endpoint_totals) + 'endpoint_distribution': endpoint_distribution, } # ------------------------------------------------------------- @@ -2418,8 +2381,10 @@ async def ps_proxy(request: Request): }) # 3. Return a JSONResponse with deduplicated currently deployed models + # Deduplicate on 'name' rather than 'digest': llama-server models always + # have digest="" so deduping on digest collapses all of them to one entry. return JSONResponse( - content={"models": dedupe_on_keys(models['models'], ['digest'])}, + content={"models": dedupe_on_keys(models['models'], ['name'])}, status_code=200, ) @@ -2565,7 +2530,7 @@ async def config_proxy(request: Request): client: aiohttp.ClientSession = app_state["session"] headers = None if "/v1" in url: - headers = {"Authorization": "Bearer " + config.api_keys[url]} + headers = {"Authorization": "Bearer " + config.api_keys.get(url, "no-key")} target_url = f"{url}/models" else: target_url = f"{url}/api/version" @@ -2625,10 +2590,7 @@ async def openai_embedding_proxy(request: Request): raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. 
Endpoint logic - endpoint = await choose_endpoint(model) - # Normalize model name for tracking so it matches the PS table key - tracking_model = get_tracking_model(endpoint, model) - await increment_usage(endpoint, tracking_model) + endpoint, tracking_model = await choose_endpoint(model) if is_openai_compatible(endpoint): api_key = config.api_keys.get(endpoint, "no-key") else: @@ -2637,13 +2599,16 @@ async def openai_embedding_proxy(request: Request): oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=api_key) - # 3. Async generator that streams embedding data and decrements the counter - async_gen = await oclient.embeddings.create(input=doc, model=model) - - await decrement_usage(endpoint, tracking_model) - - # 5. Return a StreamingResponse backed by the generator - return async_gen + try: + async_gen = await oclient.embeddings.create(input=doc, model=model) + result = async_gen.model_dump() + for item in result.get("data", []): + emb = item.get("embedding") + if emb: + item["embedding"] = [0.0 if isinstance(v, float) and not math.isfinite(v) else v for v in emb] + return JSONResponse(content=result) + finally: + await decrement_usage(endpoint, tracking_model) # ------------------------------------------------------------- # 22. 
API route – OpenAI compatible Chat Completions @@ -2676,12 +2641,21 @@ async def openai_chat_completions_proxy(request: Request): logprobs = payload.get("logprobs") top_logprobs = payload.get("top_logprobs") + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + if not isinstance(messages, list): + raise HTTPException( + status_code=400, detail="Missing required field 'messages' (must be a list)" + ) + if ":latest" in model: model = model.split(":latest") model = model[0] params = { - "messages": messages, + "messages": messages, "model": model, } @@ -2703,23 +2677,11 @@ async def openai_chat_completions_proxy(request: Request): } params.update({k: v for k, v in optional_params.items() if v is not None}) - - if not model: - raise HTTPException( - status_code=400, detail="Missing required field 'model'" - ) - if not isinstance(messages, list): - raise HTTPException( - status_code=400, detail="Missing required field 'messages' (must be a list)" - ) except orjson.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Endpoint logic - endpoint = await choose_endpoint(model) - # Normalize model name for tracking so it matches the PS table key - tracking_model = get_tracking_model(endpoint, model) - await increment_usage(endpoint, tracking_model) + endpoint, tracking_model = await choose_endpoint(model) base_url = ep2base(endpoint) oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) # 3. 
@app.post("/v1/rerank")
@app.post("/rerank")
async def rerank_proxy(request: Request):
    """
    Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint.

    Compatible with the Jina/Cohere rerank API convention used by llama-server,
    vLLM, and services such as Cohere and Jina AI.

    Ollama does not natively support reranking; requests routed to a plain Ollama
    endpoint will receive a 501 Not Implemented response.

    Request body:
        model (str, required)            – reranker model name
        query (str, required)            – search query
        documents (list[str], required)  – candidate documents to rank
        top_n (int, optional)            – limit returned results (default: all)
        return_documents (bool, optional) – include document text in results
        max_tokens_per_doc (int, optional) – truncation limit per document

    Response (Jina/Cohere-compatible):
        {
          "id": "...",
          "model": "...",
          "usage": {"prompt_tokens": N, "total_tokens": N},
          "results": [{"index": 0, "relevance_score": 0.95}, ...]
        }

    Raises:
        HTTPException 400 – invalid JSON or missing/invalid required fields
        HTTPException 404 – no endpoint could be chosen for the model
        HTTPException 501 – the chosen endpoint is a plain Ollama instance
        HTTPException 502 – the upstream was unreachable or returned non-JSON
        HTTPException <upstream status> – the upstream answered with >= 400
    """
    # 1. Parse and validate the request body.
    #    orjson accepts bytes directly and reports invalid UTF-8 as a
    #    JSONDecodeError, so a malformed body is a 400 rather than a 500.
    try:
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    query = payload.get("query")
    documents = payload.get("documents")

    # `model` is used with string operations below, so reject non-strings here.
    if not model or not isinstance(model, str):
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not query:
        raise HTTPException(status_code=400, detail="Missing required field 'query'")
    if not isinstance(documents, list) or not documents:
        raise HTTPException(status_code=400, detail="Missing or empty required field 'documents' (must be a non-empty list)")

    # 2. Determine which endpoint serves this model.  choose_endpoint() also
    #    reserves a usage slot, which must be released on every exit path.
    try:
        endpoint, tracking_model = await choose_endpoint(model)
    except RuntimeError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    # Everything after the reservation lives inside try/finally so that an
    # unexpected error can never leak the reserved slot.
    try:
        # Ollama endpoints have no native rerank support
        if not is_openai_compatible(endpoint):
            raise HTTPException(
                status_code=501,
                detail=(
                    f"Endpoint '{endpoint}' is a plain Ollama instance which does not support "
                    "reranking. Use a llama-server or OpenAI-compatible endpoint with a "
                    "dedicated reranker model."
                ),
            )

        if ":latest" in model:
            model = model.split(":latest")[0]

        # Build upstream rerank request body – forward only recognised fields
        upstream_payload: dict = {"model": model, "query": query, "documents": documents}
        for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
            if optional_key in payload:
                upstream_payload[optional_key] = payload[optional_key]

        # Determine upstream URL:
        #   llama-server exposes /v1/rerank; its configured endpoint may or may
        #   not already contain /v1.
        #   External OpenAI endpoints expose /rerank under their /v1 base.
        if endpoint in config.llama_server_endpoints:
            if "/v1" in endpoint:
                rerank_url = f"{endpoint}/rerank"
            else:
                rerank_url = f"{endpoint}/v1/rerank"
        else:
            # External OpenAI-compatible: ep2base gives us the /v1 base
            rerank_url = f"{ep2base(endpoint)}/rerank"

        api_key = config.api_keys.get(endpoint, "no-key")
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }

        client: aiohttp.ClientSession = app_state["session"]
        async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp:
            response_bytes = await resp.read()
            upstream_status = resp.status

        if upstream_status >= 400:
            raise HTTPException(
                status_code=upstream_status,
                detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")),
            )

        try:
            data = orjson.loads(response_bytes)
        except orjson.JSONDecodeError as e:
            # Upstream claimed success but did not send JSON – that is a
            # gateway problem, not an internal server error.
            raise HTTPException(
                status_code=502,
                detail="Upstream rerank endpoint returned a non-JSON response",
            ) from e

        # Record token usage if the upstream returned a usage object
        usage = data.get("usage") if isinstance(data, dict) else None
        if isinstance(usage, dict):
            prompt_tok = usage.get("prompt_tokens") or 0
            total_tok = usage.get("total_tokens") or 0
            # For reranking there are no completion tokens.  Some upstreams
            # only report total_tokens, so fall back to it instead of
            # recording zero input tokens.
            input_tok = prompt_tok or total_tok
            if input_tok:
                await token_queue.put((endpoint, tracking_model, input_tok, 0))

        return JSONResponse(content=data)
    except aiohttp.ClientError as e:
        # Connection refused, reset, DNS failure, … – surface as Bad Gateway.
        raise HTTPException(
            status_code=502,
            detail=_mask_secrets(f"Upstream rerank request failed: {e}"),
        ) from e
    finally:
        await decrement_usage(endpoint, tracking_model)
------------------------------------------------------------- +# 26. Serve the static front‑end # ------------------------------------------------------------- app.mount("/static", StaticFiles(directory="static"), name="static") @@ -3095,7 +3171,7 @@ async def usage_stream(request: Request): # ------------------------------------------------------------- @app.on_event("startup") async def startup_event() -> None: - global config, db + global config, db, token_worker_task, flush_task # Load YAML config (or use defaults if not present) config_path = _config_path_from_env() config = Config.from_yaml(config_path)