diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..a0d14ed --- /dev/null +++ b/config.yaml @@ -0,0 +1,8 @@ +# config.yaml +endpoints: + - http://192.168.0.50:11434 + - http://192.168.0.51:11434 + - http://192.168.0.52:11434 + +# Maximum concurrent connections *per endpoint‑model pair* +max_concurrent_connections: 2 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..988d1df --- /dev/null +++ b/requirements.txt @@ -0,0 +1,21 @@ +annotated-types==0.7.0 +anyio==4.10.0 +certifi==2025.8.3 +click==8.2.1 +exceptiongroup==1.3.0 +fastapi==0.116.1 +h11==0.16.0 +httpcore==1.0.9 +httpx==0.28.1 +idna==3.10 +ollama==0.5.3 +pydantic==2.11.7 +pydantic-settings==2.10.1 +pydantic_core==2.33.2 +python-dotenv==1.1.1 +PyYAML==6.0.2 +sniffio==1.3.1 +starlette==0.47.2 +typing-inspection==0.4.1 +typing_extensions==4.14.1 +uvicorn==0.35.0 diff --git a/router.py b/router.py new file mode 100644 index 0000000..3e864db --- /dev/null +++ b/router.py @@ -0,0 +1,753 @@ +""" +title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing +author: alpha-nerd-nomyo +author_url: https://github.com/nomyo-ai +version: 0.1 +license: AGPL +""" +# ------------------------------------------------------------- +import json, random, asyncio, yaml, httpx, ollama +from pathlib import Path +from typing import Dict, Set, List +from fastapi import FastAPI, Request, HTTPException +from starlette.responses import StreamingResponse, JSONResponse, Response +from pydantic import Field +from pydantic_settings import BaseSettings +from collections import defaultdict + +# ------------------------------------------------------------- +# 1. Configuration loader +# ------------------------------------------------------------- +class Config(BaseSettings): + # List of Ollama endpoints + endpoints: list[str] = Field( + default_factory=lambda: [ + "http://localhost:11434", + ] + ) + # Max concurrent connections per endpoint‑model pair, see OLLAMA_NUM_PARALLEL + max_concurrent_connections: int = 1 + + class Config: + # Load from `config.yaml` first, then from env variables + env_prefix = "OLLAMA_PROXY_" + yaml_file = Path("config.yaml") # relative to cwd + + @classmethod + def from_yaml(cls, path: Path) -> "Config": + """Load the YAML file and create the Config instance.""" + if path.exists(): + with path.open("r", encoding="utf-8") as fp: + data = yaml.safe_load(fp) or {} + return cls(**data) + return cls() + +# Create the global config object – it will be overwritten on startup +config = Config() + +# ------------------------------------------------------------- +# 2. FastAPI application +# ------------------------------------------------------------- +app = FastAPI() + +# ------------------------------------------------------------- +# 3. Global state: per‑endpoint per‑model active connection counters +# ------------------------------------------------------------- +usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int)) +usage_lock = asyncio.Lock() # protects access to usage_counts + +# ------------------------------------------------------------- +# 4. Helperfunctions +# ------------------------------------------------------------- +async def fetch_loaded_models(endpoint: str) -> Set[str]: + """ + Query /api/ps and return a set of model names that are currently + loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty + set is returned. + """ + try: + async with httpx.AsyncClient(timeout=1.0) as client: + resp = await client.get(f"{endpoint}/api/ps") + resp.raise_for_status() + data = resp.json() + # The response format is: + # {"models": [{"name": "model1"}, {"name": "model2"}]} + models = {m.get("name") for m in data.get("models", []) if m.get("name")} + return models + except Exception: + # If anything goes wrong we simply assume the endpoint has no models + return set() + +async def fetch_endpoint_details(endpoint: str, route: str, detail: str) -> List[dict]: + """ + Query / to fetch and return a List of dicts with details + for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail. + """ + try: + async with httpx.AsyncClient(timeout=1.0) as client: + resp = await client.get(f"{endpoint}{route}") + resp.raise_for_status() + data = resp.json() + detail = data.get(detail) + return detail + except Exception: + # If anything goes wrong we cannot reply versions + return {detail: "N/A"} + +def dedupe_on_keys(dicts, key_fields): + """ + Helper function to deduplicate endpoint details based on given dict keys. + """ + seen = set() + out = [] + for d in dicts: + # Build a tuple of the values for the chosen keys + key = tuple(d.get(k) for k in key_fields) + if key not in seen: + seen.add(key) + out.append(d) + return out + +# ------------------------------------------------------------- +# 5. Endpoint selection logic (respecting the configurable limit) +# ------------------------------------------------------------- +async def choose_endpoint(model: str) -> str: + """ + Determine which endpoint to use for the given model while respecting the + `max_concurrent_connections` per endpoint‑model pair. + + The selection algorithm is as follows: + + 1. Find all endpoints that have the model *already loaded* and that still + have a free slot (< max_concurrent_connections). + 2. If none exist, find any endpoint that has a free slot regardless of + whether the model is loaded – the endpoint will load the model on demand. + 3. If all endpoints are at capacity for this model, pick any endpoint + arbitrarily – the request will queue on that endpoint. + """ + # Gather loaded models for all endpoints concurrently + tasks = [fetch_loaded_models(ep) for ep in config.endpoints] + loaded_sets = await asyncio.gather(*tasks) + + async with usage_lock: + # 1️⃣ Endpoints that have the model loaded *and* a free slot + loaded_and_free = [ + ep for ep, models in zip(config.endpoints, loaded_sets) + if model in models and usage_counts[ep].get(model, 0) < config.max_concurrent_connections + ] + + if loaded_and_free: + # Prefer an endpoint that already hosts the model and has capacity + return random.choice(loaded_and_free) + + # 2️⃣ Endpoints that simply have a free slot (model may or may not be loaded) + endpoints_with_free_slot = [ + ep for ep in config.endpoints + if usage_counts[ep].get(model, 0) < config.max_concurrent_connections + ] + + if endpoints_with_free_slot: + return random.choice(endpoints_with_free_slot) + + # 3️⃣ All endpoints are at capacity – pick any (will queue on that endpoint according to OLLAMA_MAX_QUEUE) + return random.choice(config.endpoints) + +async def increment_usage(endpoint: str, model: str) -> None: + async with usage_lock: + usage_counts[endpoint][model] += 1 + +async def decrement_usage(endpoint: str, model: str) -> None: + async with usage_lock: + # Avoid negative counts + current = usage_counts[endpoint].get(model, 0) + if current > 0: + usage_counts[endpoint][model] = current - 1 + # Optionally, clean up zero entries + if usage_counts[endpoint].get(model, 0) == 0: + usage_counts[endpoint].pop(model, None) + if not usage_counts[endpoint]: + usage_counts.pop(endpoint, None) + +# ------------------------------------------------------------- +# 6. API route – Generate +# ------------------------------------------------------------- +@app.post("/api/generate") +async def proxy(request: Request): + """ + Proxy a generate request to Ollama and stream the response back to the client. + """ + # 1. Parse and validate request + try: + body_bytes = await request.body() + payload = json.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + prompt = payload.get("prompt") + suffix = payload.get("suffix") + system = payload.get("system") + template = payload.get("template") + context = payload.get("context") + stream = payload.get("stream") + think = payload.get("think") + raw = payload.get("raw") + format = payload.get("format") + images = payload.get("images") + options = payload.get("options") + keep_alive = payload.get("keep_alive") + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + if not prompt: + raise HTTPException( + status_code=400, detail="Missing required field 'prompt'" + ) + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # 2. Decide which endpoint to use + endpoint = await choose_endpoint(model) + + # Increment usage counter for this endpoint‑model pair + await increment_usage(endpoint, model) + + # 3. Create Ollama client instance + client = ollama.AsyncClient(host=endpoint) + + # 4. Async generator that streams data and decrements the counter + async def stream_generate_response(): + try: + async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=format, images=images, options=options, keep_alive=keep_alive) + if stream == True: + async for chunk in async_gen: + if hasattr(chunk, "model_dump_json"): + json_line = chunk.model_dump_json() + else: + json_line = json.dumps(chunk) + yield json_line.encode("utf-8") + b"\n" + else: + json_line = ( + async_gen.model_dump_json() + if hasattr(async_gen, "model_dump_json") + else json.dumps(async_gen) + ) + yield json_line.encode("utf-8") + b"\n" + + finally: + # Ensure counter is decremented even if an exception occurs + await decrement_usage(endpoint, model) + + # 5. Return a StreamingResponse backed by the generator + return StreamingResponse( + stream_generate_response(), + media_type="application/json", + ) + +# ------------------------------------------------------------- +# 7. API route – Chat +# ------------------------------------------------------------- +@app.post("/api/chat") +async def chat_proxy(request: Request): + """ + Proxy a chat request to Ollama and stream the endpoint reply. + """ + # 1. Parse and validate request + try: + body_bytes = await request.body() + payload = json.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + messages = payload.get("messages") + tools = payload.get("tools") + stream = payload.get("stream") + think = payload.get("think") + format = payload.get("format") + options = payload.get("options") + keep_alive = payload.get("keep_alive") + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + if not isinstance(messages, list): + raise HTTPException( + status_code=400, detail="Missing or invalid 'message' field (must be a list)" + ) + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # 2. Endpoint logic + endpoint = await choose_endpoint(model) + await increment_usage(endpoint, model) + client = ollama.AsyncClient(host=endpoint) + + # 3. Async generator that streams chat data and decrements the counter + async def stream_chat_response(): + try: + # The chat method returns a generator of dicts (or GenerateResponse) + async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive) + if stream == True: + async for chunk in async_gen: + # `chunk` can be a dict or a pydantic model – dump to JSON safely + if hasattr(chunk, "model_dump_json"): + json_line = chunk.model_dump_json() + else: + json_line = json.dumps(chunk) + yield json_line.encode("utf-8") + b"\n" + else: + json_line = ( + async_gen.model_dump_json() + if hasattr(async_gen, "model_dump_json") + else json.dumps(async_gen) + ) + yield json_line.encode("utf-8") + b"\n" + + finally: + # Ensure counter is decremented even if an exception occurs + await decrement_usage(endpoint, model) + + # 4. Return a StreamingResponse backed by the generator + return StreamingResponse( + stream_chat_response(), + media_type="application/json", + ) + +# ------------------------------------------------------------- +# 8. API route – Embedding +# ------------------------------------------------------------- +@app.post("/api/embeddings") +async def embedding_proxy(request: Request): + """ + Proxy an embedding request to Ollama and reply with embeddings. + + """ + # 1. Parse and validate request + try: + body_bytes = await request.body() + payload = json.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + prompt = payload.get("prompt") + options = payload.get("options") + keep_alive = payload.get("keep_alive") + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + if not prompt: + raise HTTPException( + status_code=400, detail="Missing required field 'prompt'" + ) + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # 2. Endpoint logic + endpoint = await choose_endpoint(model) + await increment_usage(endpoint, model) + client = ollama.AsyncClient(host=endpoint) + + # 3. Async generator that streams embedding data and decrements the counter + async def stream_embedding_response(): + try: + # The chat method returns a generator of dicts (or GenerateResponse) + async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive) + if hasattr(async_gen, "model_dump_json"): + json_line = async_gen.model_dump_json() + else: + json_line = json.dumps(async_gen) + yield json_line.encode("utf-8") + b"\n" + finally: + # Ensure counter is decremented even if an exception occurs + await decrement_usage(endpoint, model) + + # 5. Return a StreamingResponse backed by the generator + return StreamingResponse( + stream_embedding_response(), + media_type="application/json", + ) + +# ------------------------------------------------------------- +# 8. API route – Embed +# ------------------------------------------------------------- +@app.post("/api/embed") +async def embed_proxy(request: Request): + """ + Proxy an embed request to Ollama and reply with embeddings. + + """ + # 1. Parse and validate request + try: + body_bytes = await request.body() + payload = json.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + input = payload.get("input") + truncate = payload.get("truncate") + options = payload.get("options") + keep_alive = payload.get("keep_alive") + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + if not input: + raise HTTPException( + status_code=400, detail="Missing required field 'input'" + ) + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # 2. Endpoint logic + endpoint = await choose_endpoint(model) + await increment_usage(endpoint, model) + client = ollama.AsyncClient(host=endpoint) + + # 3. Async generator that streams embed data and decrements the counter + async def stream_embedding_response(): + try: + # The chat method returns a generator of dicts (or GenerateResponse) + async_gen = await client.embed(model=model, input=input, truncate=truncate, options=options, keep_alive=keep_alive) + if hasattr(async_gen, "model_dump_json"): + json_line = async_gen.model_dump_json() + else: + json_line = json.dumps(async_gen) + yield json_line.encode("utf-8") + b"\n" + finally: + # Ensure counter is decremented even if an exception occurs + await decrement_usage(endpoint, model) + + # 4. Return a StreamingResponse backed by the generator + return StreamingResponse( + stream_embedding_response(), + media_type="application/json", + ) + +# ------------------------------------------------------------- +# 9. API route – Create +# ------------------------------------------------------------- +@app.post("/api/create") +async def create_proxy(request: Request): + """ + Proxy a create request to all Ollama endpoints and reply with deduplicated status. + """ + try: + body_bytes = await request.body() + payload = json.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + quantize = payload.get("quantize") + from_ = payload.get("from") + files = payload.get("files") + adapters = payload.get("adapters") + template = payload.get("template") + license = payload.get("license") + system = payload.get("system") + parameters = payload.get("parameters") + messages = payload.get("messages") + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + if not from_ and not files: + raise HTTPException( + status_code=400, detail="You need to provide either from_ or files parameter!" + ) + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + status_lists = [] + for endpoint in config.endpoints: + client = ollama.AsyncClient(host=endpoint) + create = await client.create(model=model, quantize=quantize, from_=from_, files=files, adapters=adapters, template=template, license=license, system=system, parameters=parameters, messages=messages, stream=False) + status_lists.append(create) + + combined_status = [] + for status_list in status_lists: + combined_status += status_list + + final_status = list(dict.fromkeys(combined_status)) + + return dict(final_status) + +# ------------------------------------------------------------- +# 10. API route – Show +# ------------------------------------------------------------- +@app.post("/api/show") +async def show_proxy(request: Request): + """ + Proxy a model show request to Ollama and reply with ShowResponse. + + """ + try: + body_bytes = await request.body() + payload = json.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # 2. Endpoint logic + endpoint = await choose_endpoint(model) + await increment_usage(endpoint, model) + client = ollama.AsyncClient(host=endpoint) + + # 3. Proxy a simple show request + show = await client.show(model=model) + + # 4. Return ShowResponse + return show + +# ------------------------------------------------------------- +# 11. API route – Copy +# ------------------------------------------------------------- +@app.post("/api/copy") +async def copy_proxy(request: Request): + """ + Proxy a model copy request to each Ollama endpoint and reply with Status Code. + + """ + # 1. Parse and validate request + try: + body_bytes = await request.body() + payload = json.loads(body_bytes.decode("utf-8")) + + src = payload.get("source") + dst = payload.get("destination") + + if not src: + raise HTTPException( + status_code=400, detail="Missing required field 'source'" + ) + if not dst: + raise HTTPException( + status_code=400, detail="Missing required field 'destination'" + ) + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # 3. Iterate over all endpoints to copy the model on each endpoint + status_list = [] + for endpoint in config.endpoints: + client = ollama.AsyncClient(host=endpoint) + # 4. Proxy a simple copy request + copy = await client.copy(source=src, destination=dst) + status_list.append(copy.status) + + # 4. Return with 200 OK if all went well, 404 if a single endpoint failed + if 404 in status_list: + return Response( + status_code=404 + ) + else: + return Response( + status_code=200 + ) + +# ------------------------------------------------------------- +# 12. API route – Delete +# ------------------------------------------------------------- +@app.delete("/api/delete") +async def delete_proxy(request: Request): + """ + Proxy a model delete request to each Ollama endpoint and reply with Status Code. + + """ + # 1. Parse and validate request + try: + body_bytes = await request.body() + payload = json.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # 2. Iterate over all endpoints to delete the model on each endpoint + status_list = [] + for endpoint in config.endpoints: + client = ollama.AsyncClient(host=endpoint) + # 3. Proxy a simple copy request + copy = await client.delete(model=model) + status_list.append(copy.status) + + # 4. Retrun 200 0K, if a single enpoint fails, respond with 404 + if 404 in status_list: + return Response( + status_code=404 + ) + else: + return Response( + status_code=200 + ) + +# ------------------------------------------------------------- +# 13. API route – Pull +# ------------------------------------------------------------- +@app.post("/api/pull") +async def pull_proxy(request: Request): + """ + Proxy a pull request to all Ollama endpoint and report status back. + """ + # 1. Parse and validate request + try: + body_bytes = await request.body() + payload = json.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + insecure = payload.get("insecure") + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # 2. Iterate over all endpoints to pull the model + status_list = [] + for endpoint in config.endpoints: + client = ollama.AsyncClient(host=endpoint) + # 3. Proxy a simple pull request + pull = await client.pull(model=model, insecure=insecure, stream=False) + status_list.append(pull) + + combined_status = [] + for status in status_list: + combined_status += status + + # 4. Report back a deduplicated status message + final_status = list(dict.fromkeys(combined_status)) + + return dict(final_status) + +# ------------------------------------------------------------- +# 14. API route – Push +# ------------------------------------------------------------- +@app.post("/api/push") +async def push_proxy(request: Request): + """ + Proxy a push request to Ollama and respond the deduplicated Ollama endpoint replies. + """ + # 1. Parse and validate request + try: + body_bytes = await request.body() + payload = json.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + insecure = payload.get("insecure") + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # 2. Iterate over all endpoints + status_list = [] + for endpoint in config.endpoints: + client = ollama.AsyncClient(host=endpoint) + # 3. Proxy a simple push request + push = await client.push(model=model, insecure=insecure, stream=False) + status_list.append(push) + + combined_status = [] + for status in status_list: + combined_status += status + + # 4. Report a deduplicated status + final_status = list(dict.fromkeys(combined_status)) + + return dict(final_status) + + +# ------------------------------------------------------------- +# 15. API route – Version +# ------------------------------------------------------------- +@app.get("/api/version") +async def version_proxy(request: Request): + """ + Proxy a version request to Ollama and reply lowest version of all endpoints. + + """ + # 1. Query all endpoints for version + tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints] + all_versions = await asyncio.gather(*tasks) + + def version_key(v): + return tuple(map(int, v.split('.'))) + + # 2. Return a JSONResponse with the min Version of all endpoints to maintain compatibility + return JSONResponse( + content={"version": str(min(all_versions, key=version_key))}, + status_code=200, + ) + +# ------------------------------------------------------------- +# 16. API route – tags +# ------------------------------------------------------------- +@app.get("/api/tags") +async def tags_proxy(request: Request): + """ + Proxy a tags request to Ollama endpoints and reply with a unique list of all models. + + """ + # 1. Query all endpoints for models + tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints] + all_models = await asyncio.gather(*tasks) + + models = {'models': []} + for modellist in all_models: + models['models'] += modellist + + # 2. Return a JSONResponse with a deduplicated list of unique models for inference + return JSONResponse( + content={"models": dedupe_on_keys(models['models'], ['digest','name'])}, + status_code=200, + ) + +# ------------------------------------------------------------- +# 17. API route – ps +# ------------------------------------------------------------- +@app.get("/api/ps") +async def ps_proxy(request: Request): + """ + Proxy a ps request to all Ollama endpoints and reply a unique list of all running models. + + """ + # 1. Query all endpoints for running models + tasks = [fetch_endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints] + loaded_models = await asyncio.gather(*tasks) + + models = {'models': []} + for modellist in loaded_models: + models['models'] += modellist + + # 25. Return a JSONResponse with deduplicated currently deployed models + return JSONResponse( + content={"models": dedupe_on_keys(models['models'], ['digest'])}, + status_code=200, + ) + +# ------------------------------------------------------------- +# 18. FastAPI startup event – load configuration +# ------------------------------------------------------------- +@app.on_event("startup") +async def startup_event() -> None: + global config + # Load YAML config (or use defaults if not present) + config = Config.from_yaml(Path("config.yaml")) + print(f"Loaded configuration:\n endpoints={config.endpoints},\n " + f"max_concurrent_connections={config.max_concurrent_connections}") \ No newline at end of file