diff --git a/requirements.txt b/requirements.txt index 53f3b9a..83ac385 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,11 +14,13 @@ fastapi-sse==1.1.1 frozenlist==1.7.0 h11==0.16.0 httpcore==1.0.9 +httpx==0.28.1 idna==3.10 jiter==0.10.0 multidict==6.6.4 ollama==0.5.3 openai==1.102.0 +pillow==11.3.0 propcache==0.3.2 pydantic==2.11.7 pydantic-settings==2.10.1 diff --git a/router.py b/router.py index 2fa59ef..4f6503f 100644 --- a/router.py +++ b/router.py @@ -2,11 +2,11 @@ title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing author: alpha-nerd-nomyo author_url: https://github.com/nomyo-ai -version: 0.3 +version: 0.4 license: AGPL """ # ------------------------------------------------------------- -import json, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random +import json, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random, base64, io from pathlib import Path from typing import Dict, Set, List, Optional from fastapi import FastAPI, Request, HTTPException @@ -17,6 +17,7 @@ from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLR from pydantic import Field from pydantic_settings import BaseSettings from collections import defaultdict +from PIL import Image # ------------------------------------------------------------------ # In‑memory caches @@ -101,9 +102,9 @@ app.add_middleware( allow_headers=["Authorization", "Content-Type"], ) default_headers={ - "HTTP-Referer": "https://nomyo.ai", - "X-Title": "NOMYO Router", - } + "HTTP-Referer": "https://nomyo.ai", + "X-Title": "NOMYO Router", + } # ------------------------------------------------------------- # 3. Global state: per‑endpoint per‑model active connection counters @@ -114,8 +115,6 @@ usage_lock = asyncio.Lock() # protects access to usage_counts # ------------------------------------------------------------- # 4. Helperfunctions # ------------------------------------------------------------- -aiotimeout = aiohttp.ClientTimeout(total=5) - def _is_fresh(cached_at: float, ttl: int) -> bool: return (time.time() - cached_at) < ttl @@ -123,6 +122,20 @@ async def _ensure_success(resp: aiohttp.ClientResponse) -> None: if resp.status >= 400: text = await resp.text() raise HTTPException(status_code=resp.status, detail=text) + +def is_ext_openai_endpoint(endpoint: str) -> bool: + if "/v1" not in endpoint: + return False + + base_endpoint = endpoint.replace('/v1', '') + if base_endpoint in config.endpoints: + return False # It's Ollama's /v1 + + # Check for default Ollama port + if ':11434' in endpoint: + return False # It's Ollama + + return True # It's an external OpenAI endpoint class fetch: async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]: @@ -274,6 +287,79 @@ def iso8601_ns(): ) return iso8601_with_ns +def is_base64(image_string): + try: + if isinstance(image_string, str) and base64.b64encode(base64.b64decode(image_string)) == image_string.encode(): + return True + except Exception as e: + return False + +def resize_image_if_needed(image_data): + try: + # Check if already data-url + if image_data.startswith("data:"): + try: + header, image_data = image_data.split(",", 1) + except ValueError: + pass + # Decode the base64 image data + image_bytes = base64.b64decode(image_data) + image = Image.open(io.BytesIO(image_bytes)) + if image.mode not in ("RGB", "L"): + image = image.convert("RGB") + + # Get current size + width, height = image.size + + # Calculate the new dimensions while maintaining aspect ratio + if width > 512 or height > 512: + aspect_ratio = width / height + if aspect_ratio > 1: # Width is larger + new_width = 512 + new_height = int(512 / aspect_ratio) + else: # Height is larger + new_height = 512 + new_width = int(512 * aspect_ratio) + + image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) + + # Encode the resized image back to base64 + buffered = io.BytesIO() + image.save(buffered, format="PNG") + resized_image_data = base64.b64encode(buffered.getvalue()).decode("utf-8") + return resized_image_data + + except Exception as e: + print(f"Error processing image: {e}") + return None + +def transform_images_to_data_urls(message_list): + for message in message_list: + if "images" in message: + images = message.pop("images") + if not isinstance(images, list): + continue + new_content = [] + for image in images: #TODO: quality downsize if images are too big to fit into model context window size + if not is_base64(image): + raise ValueError(f"Image string is not a valid base64 encoded string.") + resized_image = resize_image_if_needed(image) + if resized_image: + data_url = f"data:image/png;base64,{resized_image}" + #new_content.append({ + # "type": "text", + # "text": "" + #}) + new_content.append({ + "type": "image_url", + "image_url": { + "url": data_url + } + }) + message["content"] = new_content + + return message_list + class rechunk: def openai_chat_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.ChatResponse: if chunk.choices == [] and chunk.usage is not None: @@ -328,12 +414,12 @@ class rechunk: created_at=iso8601_ns(), done=True if chunk.usage is not None else False, done_reason=chunk.choices[0].finish_reason, - total_duration=int((time.perf_counter() - start_ts) * 1000) if chunk.usage is not None else 0, + total_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0, load_duration=10000, prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0, prompt_eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0, eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0, - eval_duration=int((time.perf_counter() - start_ts) * 1000) if chunk.usage is not None else 0, + eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0, response=chunk.choices[0].text or '', thinking=thinking) return rechunk @@ -436,8 +522,13 @@ async def choose_endpoint(model: str) -> str: # 6️⃣ if not candidate_endpoints: if ":latest" in model: #ollama naming convention not applicable to openai - model = model.split(":latest") - model = model[0] + model_without_latest = model.split(":latest")[0] + candidate_endpoints = [ + ep for ep, models in zip(config.endpoints, advertised_sets) + if model_without_latest in models and is_ext_openai_endpoint(ep) + ] + if not candidate_endpoints: + model = model + ":latest" candidate_endpoints = [ ep for ep, models in zip(config.endpoints, advertised_sets) if model in models @@ -516,7 +607,8 @@ async def proxy(request: Request): status_code=400, detail="Missing required field 'prompt'" ) except json.JSONDecodeError as e: - raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + error_msg = f"Invalid JSON format in request body: {str(e)}. Please ensure the request is properly formatted." + raise HTTPException(status_code=400, detail=error_msg) from e endpoint = await choose_endpoint(model) @@ -539,7 +631,7 @@ async def proxy(request: Request): "stop": options.get("stop") if options and "stop" in options else None, "top_p": options.get("top_p") if options and "top_p" in options else None, "temperature": options.get("temperature") if options and "temperature" in options else None, - "sufix": suffix, + "suffix": suffix, } params.update({k: v for k, v in optional_params.items() if v is not None}) oclient = openai.AsyncOpenAI(base_url=endpoint, default_headers=default_headers, api_key=config.api_keys[endpoint]) @@ -608,7 +700,7 @@ async def chat_proxy(request: Request): _format = payload.get("format") keep_alive = payload.get("keep_alive") options = payload.get("options") - + if not model: raise HTTPException( status_code=400, detail="Missing required field 'model'" @@ -631,6 +723,8 @@ async def chat_proxy(request: Request): if ":latest" in model: model = model.split(":latest") model = model[0] + if messages: + messages = transform_images_to_data_urls(messages) params = { "messages": messages, "model": model, @@ -664,11 +758,9 @@ async def chat_proxy(request: Request): async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive) if stream == True: async for chunk in async_gen: - print(chunk) if is_openai_endpoint: chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts) # `chunk` can be a dict or a pydantic model – dump to JSON safely - print(chunk) if hasattr(chunk, "model_dump_json"): json_line = chunk.model_dump_json() else: @@ -988,7 +1080,7 @@ async def delete_proxy(request: Request, model: Optional[str] = None): copy = await client.delete(model=model) status_list.append(copy.status) - # 4. Retrun 200 0K, if a single enpoint fails, respond with 404 + # 4. Return 200 0K, if a single enpoint fails, respond with 404 return Response(status_code=404 if 404 in status_list else 200) # ------------------------------------------------------------- @@ -1316,7 +1408,6 @@ async def openai_chat_completions_proxy(request: Request): await increment_usage(endpoint, model) base_url = ep2base(endpoint) oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys[endpoint]) - # 3. Async generator that streams completions data and decrements the counter async def stream_ochat_response(): try: @@ -1331,8 +1422,7 @@ async def openai_chat_completions_proxy(request: Request): ) if chunk.choices[0].delta.content is not None: yield f"data: {data}\n\n".encode("utf-8") - # Final DONE event - #yield b"data: [DONE]\n\n" + yield b"data: [DONE]\n\n" else: json_line = ( async_gen.model_dump_json() @@ -1580,7 +1670,7 @@ async def startup_event() -> None: ssl_context = ssl.create_default_context() connector = aiohttp.TCPConnector(limit=0, limit_per_host=512, ssl=ssl_context) - timeout = aiohttp.ClientTimeout(total=5, connect=5, sock_read=120, sock_connect=5) + timeout = aiohttp.ClientTimeout(total=60, connect=15, sock_read=120, sock_connect=15) session = aiohttp.ClientSession(connector=connector, timeout=timeout) app_state["connector"] = connector @@ -1590,4 +1680,4 @@ async def startup_event() -> None: @app.on_event("shutdown") async def shutdown_event() -> None: await close_all_sse_queues() - await app_state["session"].close() \ No newline at end of file + await app_state["session"].close() diff --git a/static/index.html b/static/index.html index dd0fcd1..ef6119b 100644 --- a/static/index.html +++ b/static/index.html @@ -424,76 +424,6 @@ }); }); - /* show logic */ - document.body.addEventListener("click", async (e) => { - if (!e.target.matches(".show-link")) return; - e.preventDefault(); - const model = e.target.dataset.model; - try { - const resp = await fetch( - `/api/show?model=${encodeURIComponent(model)}`, - { method: "POST" }, - ); - if (!resp.ok) - throw new Error(`Status ${resp.status}`); - const data = await resp.json(); - document.getElementById("json-output").textContent = - JSON.stringify(data, null, 2).replace( - /\\n/g, - "\n", - ); - document.getElementById( - "show-modal", - ).style.display = "flex"; - } catch (err) { - console.error(err); - alert( - `Could not load model details: ${err.message}`, - ); - } - }); - - /* pull logic */ - document - .getElementById("pull-btn") - .addEventListener("click", async () => { - const model = document - .getElementById("pull-model-input") - .value.trim(); - const statusEl = - document.getElementById("pull-status"); - if (!model) { - alert("Please enter a model name."); - return; - } - try { - const resp = await fetch( - `/api/pull?model=${encodeURIComponent(model)}`, - { method: "POST" }, - ); - if (!resp.ok) - throw new Error(`Status ${resp.status}`); - const data = await resp.json(); - statusEl.textContent = `✅ ${data.status}`; - statusEl.style.color = "green"; - loadTags(); - } catch (err) { - console.error(err); - statusEl.textContent = `❌ ${err.message}`; - statusEl.style.color = "red"; - } - }); - - /* modal close */ - const modal = document.getElementById("show-modal"); - modal.addEventListener("click", (e) => { - if ( - e.target === modal || - e.target.matches(".close-btn") - ) { - modal.style.display = "none"; - } - }); } catch (e) { console.error(e); } @@ -592,6 +522,77 @@ loadUsage(); setInterval(loadPS, 60_000); setInterval(loadEndpoints, 300_000); + + /* show logic */ + document.body.addEventListener("click", async (e) => { + if (!e.target.matches(".show-link")) return; + e.preventDefault(); + const model = e.target.dataset.model; + try { + const resp = await fetch( + `/api/show?model=${encodeURIComponent(model)}`, + { method: "POST" }, + ); + if (!resp.ok) + throw new Error(`Status ${resp.status}`); + const data = await resp.json(); + document.getElementById("json-output").textContent = + JSON.stringify(data, null, 2).replace( + /\\n/g, + "\n", + ); + document.getElementById( + "show-modal", + ).style.display = "flex"; + } catch (err) { + console.error(err); + alert( + `Could not load model details: ${err.message}`, + ); + } + }); + + /* pull logic */ + document + .getElementById("pull-btn") + .addEventListener("click", async () => { + const model = document + .getElementById("pull-model-input") + .value.trim(); + const statusEl = + document.getElementById("pull-status"); + if (!model) { + alert("Please enter a model name."); + return; + } + try { + const resp = await fetch( + `/api/pull?model=${encodeURIComponent(model)}`, + { method: "POST" }, + ); + if (!resp.ok) + throw new Error(`Status ${resp.status}`); + const data = await resp.json(); + statusEl.textContent = `✅ ${data.status}`; + statusEl.style.color = "green"; + loadTags(); + } catch (err) { + console.error(err); + statusEl.textContent = `❌ ${err.message}`; + statusEl.style.color = "red"; + } + }); + + /* modal close */ + const modal = document.getElementById("show-modal"); + modal.addEventListener("click", (e) => { + if ( + e.target === modal || + e.target.matches(".close-btn") + ) { + modal.style.display = "none"; + } + }); });