diff --git a/.gitignore b/.gitignore
index 7cd8431..cfce37c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -66,4 +66,7 @@ config.yaml
 
 # SQLite
 *.db*
-*settings.json
\ No newline at end of file
+*settings.json
+
+# Test suite (local only, not committed yet)
+test/
\ No newline at end of file
diff --git a/README.md b/README.md
index ef3e6f2..1b952d9 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 # NOMYO Router
 
-is a transparent proxy for [Ollama](https://github.com/ollama/ollama) with model deployment aware routing.
+is a transparent proxy for inference engines, e.g. [Ollama](https://github.com/ollama/ollama), [llama.cpp](https://github.com/ggml-org/llama.cpp/), [vllm](https://github.com/vllm-project/vllm) or any OpenAI V1 compatible endpoint with model deployment aware routing.
 
 [![Click for video](https://github.com/user-attachments/assets/ddacdf88-e3f3-41dd-8be6-f165b22d9879)](https://eu1.nomyo.ai/assets/dash.mp4)
 
diff --git a/config.yaml b/config.yaml
index 4d7a5e4..76fbbe1 100644
--- a/config.yaml
+++ b/config.yaml
@@ -9,8 +9,23 @@ llama_server_endpoints:
   - http://192.168.0.50:8889/v1
 
 # Maximum concurrent connections *per endpoint‑model pair* (equals to OLLAMA_NUM_PARALLEL)
+# This is the global default; individual endpoints can override it via endpoint_config below.
 max_concurrent_connections: 2
 
+# Per-endpoint overrides (optional). Any field not listed falls back to the global default.
+# endpoint_config:
+#   "http://192.168.0.50:11434":
+#     max_concurrent_connections: 3
+#   "http://192.168.0.51:11434":
+#     max_concurrent_connections: 1
+
+# Priority / WRR routing (optional, default: false).
+# When true, requests are routed by utilization ratio (usage/max_concurrent_connections)
+# and the config order of endpoints acts as the tiebreaker — the first endpoint listed
+# is preferred when two endpoints are equally loaded.
+# When false (default), equally-idle endpoints are chosen at random.
+# priority_routing: true + # Optional router-level API key that gates router/API/web UI access (leave empty to disable) nomyo-router-api-key: "" diff --git a/router.py b/router.py index 9c02077..995eb78 100644 --- a/router.py +++ b/router.py @@ -6,7 +6,7 @@ version: 0.7 license: AGPL """ # ------------------------------------------------------------- -import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math, socket +import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math, socket, httpx try: import truststore; truststore.inject_into_ssl() except ImportError: @@ -185,6 +185,8 @@ _CTX_TRIM_SMALL_LIMIT = 32768 # only proactively trim models with n_ctx at or b app_state = { "session": None, "connector": None, + "socket_sessions": {}, # endpoint -> aiohttp.ClientSession(UnixConnector) for .sock endpoints + "httpx_clients": {}, # endpoint -> httpx.AsyncClient(UDS transport) for .sock endpoints } token_worker_task: asyncio.Task | None = None flush_task: asyncio.Task | None = None @@ -216,6 +218,10 @@ class Config(BaseSettings): llama_server_endpoints: List[str] = Field(default_factory=list) # Max concurrent connections per endpoint‑model pair, see OLLAMA_NUM_PARALLEL max_concurrent_connections: int = 1 + # Per-endpoint overrides: {endpoint_url: {max_concurrent_connections: N}} + endpoint_config: Dict[str, Dict] = Field(default_factory=dict) + # When True, config order = priority; routes by utilization ratio + config index (WRR) + priority_routing: bool = Field(default=False) api_keys: Dict[str, str] = Field(default_factory=dict) # Optional router-level API key used to gate access to this service and dashboard @@ -494,6 +500,65 @@ def _extract_llama_quant(name: str) -> str: return name.rsplit(":", 1)[1] return "" + +def _is_unix_socket_endpoint(endpoint: str) -> bool: + """Return True if endpoint uses Unix socket (.sock hostname convention). 
+
+    Detects URLs like http://192.168.0.52.sock/v1 where the host ends with
+    .sock, indicating the connection should use a Unix domain socket at
+    /run/user/<uid>/<host> instead of TCP.
+    """
+    try:
+        host = endpoint.split("//", 1)[1].split("/")[0].split(":")[0]
+        return host.endswith(".sock")
+    except IndexError:
+        return False
+
+
+def _get_socket_path(endpoint: str) -> str:
+    """Derive Unix socket file path from a .sock endpoint URL.
+
+    http://192.168.0.52.sock/v1 -> /run/user/<uid>/192.168.0.52.sock
+    """
+    host = endpoint.split("//", 1)[1].split("/")[0].split(":")[0]
+    return f"/run/user/{os.getuid()}/{host}"
+
+
+def get_session(endpoint: str) -> aiohttp.ClientSession:
+    """Return the appropriate aiohttp session for the given endpoint.
+
+    Unix socket endpoints (.sock) get their own UnixConnector session.
+    All other endpoints share the main TCP session.
+    """
+    if _is_unix_socket_endpoint(endpoint):
+        sess = app_state["socket_sessions"].get(endpoint)
+        if sess is not None:
+            return sess
+    return app_state["session"]
+
+
+def _make_openai_client(
+    endpoint: str,
+    default_headers: dict | None = None,
+    api_key: str = "no-key",
+) -> openai.AsyncOpenAI:
+    """Return an AsyncOpenAI client configured for the given endpoint.
+
+    For Unix socket endpoints, injects a pre-created httpx UDS transport
+    so the OpenAI SDK connects via the socket instead of TCP.
+    """
+    base_url = ep2base(endpoint)
+    kwargs: dict = {"api_key": api_key}
+    if default_headers is not None:
+        kwargs["default_headers"] = default_headers
+    if _is_unix_socket_endpoint(endpoint):
+        http_client = app_state["httpx_clients"].get(endpoint)
+        if http_client is not None:
+            kwargs["http_client"] = http_client
+            base_url = "http://localhost/v1"
+    return openai.AsyncOpenAI(base_url=base_url, **kwargs)
+
+
 def _is_llama_model_loaded(item: dict) -> bool:
     """Return True if a llama-server /v1/models item has status 'loaded'.
     Handles both dict format ({"value": "loaded"}) and plain string ("loaded").
@@ -733,7 +798,7 @@ class fetch: endpoint_url = f"{endpoint}/api/tags" key = "models" - client: aiohttp.ClientSession = app_state["session"] + client: aiohttp.ClientSession = get_session(endpoint) try: async with client.get(endpoint_url, headers=headers) as resp: await _ensure_success(resp) @@ -854,8 +919,8 @@ class fetch: For Ollama endpoints: queries /api/ps and returns model names For llama-server endpoints: queries /v1/models and filters for status.value == "loaded" """ - client: aiohttp.ClientSession = app_state["session"] - + client: aiohttp.ClientSession = get_session(endpoint) + # Check if this is a llama-server endpoint if endpoint in config.llama_server_endpoints: # Query /v1/models for llama-server @@ -997,7 +1062,7 @@ class fetch: if _is_fresh(_available_error_cache[endpoint], 300): return [] - client: aiohttp.ClientSession = app_state["session"] + client: aiohttp.ClientSession = get_session(endpoint) headers = None if api_key is not None: headers = {"Authorization": "Bearer " + api_key} @@ -1089,7 +1154,7 @@ async def _make_chat_request(model: str, messages: list, tools=None, stream: boo "response_format": {"type": "json_schema", "json_schema": format} if format is not None else None } params.update({k: v for k, v in optional_params.items() if v is not None}) - oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) + oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) @@ -1636,6 +1701,12 @@ async def get_usage_counts() -> Dict: # ------------------------------------------------------------- # 5. 
Endpoint selection logic (respecting the configurable limit) # ------------------------------------------------------------- +def get_max_connections(ep: str) -> int: + """Per-endpoint max_concurrent_connections, falling back to the global value.""" + return config.endpoint_config.get(ep, {}).get( + "max_concurrent_connections", config.max_concurrent_connections + ) + async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]: """ Determine which endpoint to use for the given model while respecting @@ -1706,41 +1777,65 @@ async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]: def tracking_usage(ep: str) -> int: return usage_counts.get(ep, {}).get(get_tracking_model(ep, model), 0) + def utilization_ratio(ep: str) -> float: + return tracking_usage(ep) / get_max_connections(ep) + + # Priority map: position in all_endpoints list (lower = higher priority) + ep_priority = {ep: i for i, ep in enumerate(all_endpoints)} + # 3️⃣ Endpoints that have the model loaded *and* a free slot loaded_and_free = [ ep for ep, models in zip(candidate_endpoints, loaded_sets) - if model in models and tracking_usage(ep) < config.max_concurrent_connections + if model in models and tracking_usage(ep) < get_max_connections(ep) ] if loaded_and_free: - # Sort ascending for load balancing — all endpoints here already have the - # model loaded, so there is no model-switching cost to optimise for. - loaded_and_free.sort(key=tracking_usage) - # When all candidates are equally idle, randomise to avoid always picking - # the first entry in a stable sort. - if all(tracking_usage(ep) == 0 for ep in loaded_and_free): - selected = random.choice(loaded_and_free) - else: + if config.priority_routing: + # WRR: sort by config order first (stable), then by utilization ratio. + # Stable sort preserves priority for equal-ratio endpoints. 
+ loaded_and_free.sort(key=lambda ep: ep_priority.get(ep, 999)) + loaded_and_free.sort(key=utilization_ratio) selected = loaded_and_free[0] + else: + # Sort ascending for load balancing — all endpoints here already have the + # model loaded, so there is no model-switching cost to optimise for. + loaded_and_free.sort(key=tracking_usage) + # When all candidates are equally idle, randomise to avoid always picking + # the first entry in a stable sort. + if all(tracking_usage(ep) == 0 for ep in loaded_and_free): + selected = random.choice(loaded_and_free) + else: + selected = loaded_and_free[0] else: # 4️⃣ Endpoints among the candidates that simply have a free slot endpoints_with_free_slot = [ ep for ep in candidate_endpoints - if tracking_usage(ep) < config.max_concurrent_connections + if tracking_usage(ep) < get_max_connections(ep) ] if endpoints_with_free_slot: - # Sort by total endpoint load (ascending) to prefer idle endpoints. - endpoints_with_free_slot.sort( - key=lambda ep: sum(usage_counts.get(ep, {}).values()) - ) - if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot): - selected = random.choice(endpoints_with_free_slot) - else: + if config.priority_routing: + endpoints_with_free_slot.sort(key=lambda ep: ep_priority.get(ep, 999)) + endpoints_with_free_slot.sort(key=utilization_ratio) selected = endpoints_with_free_slot[0] + else: + # Sort by total endpoint load (ascending) to prefer idle endpoints. 
+ endpoints_with_free_slot.sort( + key=lambda ep: sum(usage_counts.get(ep, {}).values()) + ) + if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot): + selected = random.choice(endpoints_with_free_slot) + else: + selected = endpoints_with_free_slot[0] else: # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue) - selected = min(candidate_endpoints, key=tracking_usage) + if config.priority_routing: + selected = min( + candidate_endpoints, + key=lambda ep: (utilization_ratio(ep), ep_priority.get(ep, 999)), + ) + else: + selected = min(candidate_endpoints, key=tracking_usage) tracking_model = get_tracking_model(selected, model) snapshot = None @@ -1822,7 +1917,7 @@ async def proxy(request: Request): "suffix": suffix, } params.update({k: v for k, v in optional_params.items() if v is not None}) - oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) + oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) @@ -2000,7 +2095,7 @@ async def chat_proxy(request: Request): "response_format": {"type": "json_schema", "json_schema": _format} if _format is not None else None } params.update({k: v for k, v in optional_params.items() if v is not None}) - oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) + oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) # For OpenAI endpoints: make the API call in handler scope @@ -2220,7 +2315,7 @@ async def embedding_proxy(request: Request): if ":latest" in model: model = model.split(":latest") model = model[0] - client = openai.AsyncOpenAI(base_url=ep2base(endpoint), 
api_key=config.api_keys.get(endpoint, "no-key")) + client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) # 3. Async generator that streams embedding data and decrements the counter @@ -2285,7 +2380,7 @@ async def embed_proxy(request: Request): if ":latest" in model: model = model.split(":latest") model = model[0] - client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key")) + client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) # 3. Async generator that streams embed data and decrements the counter @@ -2834,7 +2929,7 @@ async def ps_details_proxy(request: Request): # Fetch /props for each llama-server model to get context length (n_ctx) # and unload sleeping models automatically async def _fetch_llama_props(endpoint: str, model_id: str) -> tuple[int | None, bool, bool]: - client: aiohttp.ClientSession = app_state["session"] + client: aiohttp.ClientSession = get_session(endpoint) base_url = endpoint.rstrip("/").removesuffix("/v1") props_url = f"{base_url}/props?model={model_id}" headers = None @@ -2907,7 +3002,7 @@ async def config_proxy(request: Request): which endpoints are being proxied. 
""" async def check_endpoint(url: str): - client: aiohttp.ClientSession = app_state["session"] + client: aiohttp.ClientSession = get_session(url) headers = None if "/v1" in url: headers = {"Authorization": "Bearer " + config.api_keys.get(url, "no-key")} @@ -2990,9 +3085,7 @@ async def openai_embedding_proxy(request: Request): api_key = config.api_keys.get(endpoint, "no-key") else: api_key = "ollama" - base_url = ep2base(endpoint) - - oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=api_key) + oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=api_key) try: async_gen = await oclient.embeddings.create(input=doc, model=model) @@ -3105,8 +3198,7 @@ async def openai_chat_completions_proxy(request: Request): # 2. Endpoint logic endpoint, tracking_model = await choose_endpoint(model) - base_url = ep2base(endpoint) - oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) + oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) # 3. Helpers and API call — done in handler scope so try/except works reliably async def _normalize_images_in_messages(msgs: list) -> list: """Fetch remote image URLs and convert them to base64 data URLs so @@ -3416,8 +3508,7 @@ async def openai_completions_proxy(request: Request): # 2. Endpoint logic endpoint, tracking_model = await choose_endpoint(model) - base_url = ep2base(endpoint) - oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) + oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) # 3. 
Async generator that streams completions data and decrements the counter # Make the API call in handler scope (try/except inside async generators is unreliable) @@ -3672,7 +3763,7 @@ async def rerank_proxy(request: Request): "Authorization": f"Bearer {api_key}", } - client: aiohttp.ClientSession = app_state["session"] + client: aiohttp.ClientSession = get_session(endpoint) try: async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp: response_bytes = await resp.read() @@ -3845,7 +3936,9 @@ async def startup_event() -> None: f"Loaded configuration from {config_path}:\n" f" endpoints={config.endpoints},\n" f" llama_server_endpoints={config.llama_server_endpoints},\n" - f" max_concurrent_connections={config.max_concurrent_connections}" + f" max_concurrent_connections={config.max_concurrent_connections},\n" + f" endpoint_config={config.endpoint_config},\n" + f" priority_routing={config.priority_routing}" ) else: print( @@ -3874,6 +3967,19 @@ async def startup_event() -> None: app_state["connector"] = connector app_state["session"] = session + + # Create per-endpoint Unix socket sessions for .sock endpoints + for ep in config.llama_server_endpoints: + if _is_unix_socket_endpoint(ep): + sock_path = _get_socket_path(ep) + sock_connector = aiohttp.UnixConnector(path=sock_path) + sock_timeout = aiohttp.ClientTimeout(total=300, connect=5, sock_read=300) + sock_session = aiohttp.ClientSession(connector=sock_connector, timeout=sock_timeout) + app_state["socket_sessions"][ep] = sock_session + transport = httpx.AsyncHTTPTransport(uds=sock_path) + app_state["httpx_clients"][ep] = httpx.AsyncClient(transport=transport, timeout=300.0) + print(f"[startup] Unix socket session: {ep} -> {sock_path}") + token_worker_task = asyncio.create_task(token_worker()) flush_task = asyncio.create_task(flush_buffer()) await init_llm_cache(config) @@ -3883,6 +3989,23 @@ async def shutdown_event() -> None: await close_all_sse_queues() await flush_remaining_buffers() await 
app_state["session"].close() + + # Close Unix socket sessions + for ep, sess in list(app_state.get("socket_sessions", {}).items()): + try: + await sess.close() + print(f"[shutdown] Closed Unix socket session: {ep}") + except Exception as e: + print(f"[shutdown] Error closing Unix socket session {ep}: {e}") + + # Close httpx Unix socket clients + for ep, client in list(app_state.get("httpx_clients", {}).items()): + try: + await client.aclose() + print(f"[shutdown] Closed httpx client: {ep}") + except Exception as e: + print(f"[shutdown] Error closing httpx client {ep}: {e}") + if token_worker_task is not None: token_worker_task.cancel() if flush_task is not None: