diff --git a/.gitignore b/.gitignore index 5dac518..4bb65cd 100644 --- a/.gitignore +++ b/.gitignore @@ -63,3 +63,6 @@ cython_debug/ # Logfile(s) *.log *.sqlite3 + +# Config +config.yaml \ No newline at end of file diff --git a/README.md b/README.md index 6ca4cd0..f6cc58f 100644 --- a/README.md +++ b/README.md @@ -18,9 +18,19 @@ endpoints: - http://ollama0:11434 - http://ollama1:11434 - http://ollama2:11434 + - https://api.openai.com/v1 # Maximum concurrent connections *per endpoint‑model pair* max_concurrent_connections: 2 + +# API keys for remote endpoints +# Set an environment variable like OPENAI_KEY +# Confirm endpoints are exactly as in endpoints block +api_keys: + "http://192.168.0.50:11434": "ollama" + "http://192.168.0.51:11434": "ollama" + "http://192.168.0.52:11434": "ollama" + "https://api.openai.com/v1": "${OPENAI_KEY}" ``` Run the NOMYO Router in a dedicated virtual environment, install the requirements and run with uvicorn: @@ -30,6 +40,13 @@ python3 -m venv .venv/router source .venv/router/bin/activate pip3 install -r requirements.txt ``` + +on the shell do: + +``` +export OPENAI_KEY=YOUR_SECRET_API_KEY +``` + finally you can ``` diff --git a/config.yaml b/config.yaml index 93ae117..bb8e8f5 100644 --- a/config.yaml +++ b/config.yaml @@ -3,8 +3,7 @@ endpoints: - http://192.168.0.50:11434 - http://192.168.0.51:11434 - http://192.168.0.52:11434 - #- https://openrouter.ai/api/v1 - #- https://api.openai.com/v1 + - https://api.openai.com/v1 # Maximum concurrent connections *per endpoint‑model pair* (equals to OLLAMA_NUM_PARALLEL) max_concurrent_connections: 2 @@ -16,5 +15,4 @@ api_keys: "http://192.168.0.50:11434": "ollama" "http://192.168.0.51:11434": "ollama" "http://192.168.0.52:11434": "ollama" - #"https://openrouter.ai/api/v1": "${OPENROUTER_KEY}" - #"https://api.openai.com/v1": "${OPENAI_KEY}" + "https://api.openai.com/v1": "${OPENAI_KEY}" diff --git a/requirements.txt b/requirements.txt index 4ffd391..df8c11a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,8 +14,6 @@ fastapi-sse==1.1.1 frozenlist==1.7.0 h11==0.16.0 httpcore==1.0.9 -httpx==0.28.1 -httpx-aiohttp==0.1.8 idna==3.10 jiter==0.10.0 multidict==6.6.4 diff --git a/router.py b/router.py index 51cf51d..374ccbc 100644 --- a/router.py +++ b/router.py @@ -2,17 +2,17 @@ title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing author: alpha-nerd-nomyo author_url: https://github.com/nomyo-ai -version: 0.2.2 +version: 0.3 license: AGPL """ # ------------------------------------------------------------- -import json, time, asyncio, yaml, httpx, ollama, openai, os, re -from httpx_aiohttp import AiohttpTransport +import json, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random from pathlib import Path from typing import Dict, Set, List, Optional from fastapi import FastAPI, Request, HTTPException from fastapi_sse import sse_handler from fastapi.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLResponse, RedirectResponse from pydantic import Field from pydantic_settings import BaseSettings @@ -33,6 +33,14 @@ _error_cache: dict[str, float] = {} _subscribers: Set[asyncio.Queue] = set() _subscribers_lock = asyncio.Lock() +# ------------------------------------------------------------------ +# aiohttp Global Sessions +# ------------------------------------------------------------------ +app_state = { + "session": None, + "connector": None, +} + # ------------------------------------------------------------- # 1. Configuration loader # ------------------------------------------------------------- @@ -85,7 +93,13 @@ config = Config() # ------------------------------------------------------------- app = FastAPI() sse_handler.app = app - +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["GET", "POST", "DELETE"], + allow_headers=["Authorization", "Content-Type"], +) # ------------------------------------------------------------- # 3. Global state: per‑endpoint per‑model active connection counters # ------------------------------------------------------------- @@ -95,128 +109,115 @@ usage_lock = asyncio.Lock() # protects access to usage_counts # ------------------------------------------------------------- # 4. Helperfunctions # ------------------------------------------------------------- +aiotimeout = aiohttp.ClientTimeout(total=5) + def _is_fresh(cached_at: float, ttl: int) -> bool: return (time.time() - cached_at) < ttl -def get_httpx_client(endpoint: str) -> httpx.AsyncClient: - """ - Use persistent connections to request endpoint info for reliable results - in high load situations or saturated endpoints. - """ - return httpx.AsyncClient( - base_url=endpoint, - timeout=httpx.Timeout(5.0, read=5.0, write=None, connect=5.0), - #limits=httpx.Limits( - # max_keepalive_connections=64, - # max_connections=64 - #), - transport=AiohttpTransport() - ) +async def _ensure_success(resp: aiohttp.ClientResponse) -> None: + if resp.status >= 400: + text = await resp.text() + raise HTTPException(status_code=resp.status, detail=text) -async def fetch_available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]: - """ - Query /api/tags and return a set of all model names that the - endpoint *advertises* (i.e. is capable of serving). This endpoint lists - every model that is installed on the Ollama instance, regardless of - whether the model is currently loaded into memory. +class fetch: + async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]: + """ + Query /api/tags and return a set of all model names that the + endpoint *advertises* (i.e. is capable of serving). This endpoint lists + every model that is installed on the Ollama instance, regardless of + whether the model is currently loaded into memory. - If the request fails (e.g. timeout, 5xx, or malformed response), an empty - set is returned. - """ - headers = None - if api_key is not None: - headers = {"Authorization": "Bearer " + api_key} + If the request fails (e.g. timeout, 5xx, or malformed response), an empty + set is returned. + """ + headers = None + if api_key is not None: + headers = {"Authorization": "Bearer " + api_key} - if endpoint in _models_cache: - models, cached_at = _models_cache[endpoint] - if _is_fresh(cached_at, 300): - return models + if endpoint in _models_cache: + models, cached_at = _models_cache[endpoint] + if _is_fresh(cached_at, 300): + return models + else: + # stale entry – drop it + del _models_cache[endpoint] + + if endpoint in _error_cache: + if _is_fresh(_error_cache[endpoint], 1): + # Still within the short error TTL – pretend nothing is available + return set() + else: + # Error expired – remove it + del _error_cache[endpoint] + + if "/v1" in endpoint: + endpoint_url = f"{endpoint}/models" + key = "data" else: - # stale entry – drop it - del _models_cache[endpoint] + endpoint_url = f"{endpoint}/api/tags" + key = "models" + client: aiohttp.ClientSession = app_state["session"] + try: + async with client.get(endpoint_url, headers=headers) as resp: + await _ensure_success(resp) + data = await resp.json() - if endpoint in _error_cache: - if _is_fresh(_error_cache[endpoint], 1): - # Still within the short error TTL – pretend nothing is available + items = data.get(key, []) + models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")} + + if models: + _models_cache[endpoint] = (models, time.time()) + return models + else: + # Empty list – treat as “no models”, but still cache for 300s + _models_cache[endpoint] = (models, time.time()) + return models + except Exception as e: + # Treat any error as if the endpoint offers no models + print(f"[fetch.available_models] {endpoint} error: {e}") + _error_cache[endpoint] = time.time() return set() - else: - # Error expired – remove it - del _error_cache[endpoint] - try: - client = get_httpx_client(endpoint) - if "/v1" in endpoint: - resp = await client.get(f"/models", headers=headers) - else: - resp = await client.get(f"/api/tags") - resp.raise_for_status() - data = resp.json() - # Expected format: - # {"models": [{"name": "model1"}, {"name": "model2"}]} - if "/v1" in endpoint: - models = {m.get("id") for m in data.get("data", []) if m.get("id")} - else: + + async def loaded_models(endpoint: str) -> Set[str]: + """ + Query /api/ps and return a set of model names that are currently + loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty + set is returned. + """ + client: aiohttp.ClientSession = app_state["session"] + try: + async with client.get(f"{endpoint}/api/ps") as resp: + await _ensure_success(resp) + data = await resp.json() + # The response format is: + # {"models": [{"name": "model1"}, {"name": "model2"}]} models = {m.get("name") for m in data.get("models", []) if m.get("name")} + return models + except Exception: + # If anything goes wrong we simply assume the endpoint has no models + return set() + + async def endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]: + """ + Query / to fetch and return a List of dicts with details + for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail. + """ + client: aiohttp.ClientSession = app_state["session"] + headers = None + if api_key is not None: + headers = {"Authorization": "Bearer " + api_key} - if models: - _models_cache[endpoint] = (models, time.time()) - return models - else: - # Empty list – treat as “no models”, but still cache for 300s - _models_cache[endpoint] = (models, time.time()) - return models - except Exception as e: - # Treat any error as if the endpoint offers no models - print(f"[fetch_available_models] {endpoint} error: {e}") - _error_cache[endpoint] = time.time() - return set() - finally: - await client.aclose() - - -async def fetch_loaded_models(endpoint: str) -> Set[str]: - """ - Query /api/ps and return a set of model names that are currently - loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty - set is returned. - """ - try: - client = get_httpx_client(endpoint) - resp = await client.get(f"/api/ps") - resp.raise_for_status() - data = resp.json() - # The response format is: - # {"models": [{"name": "model1"}, {"name": "model2"}]} - models = {m.get("name") for m in data.get("models", []) if m.get("name")} - return models - except Exception: - # If anything goes wrong we simply assume the endpoint has no models - return set() - finally: - await client.aclose() - -async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]: - """ - Query / to fetch and return a List of dicts with details - for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail. - """ - headers = None - if api_key is not None: - headers = {"Authorization": "Bearer " + api_key} - - try: - client = get_httpx_client(endpoint) - resp = await client.get(f"{route}", headers=headers) - resp.raise_for_status() - data = resp.json() - detail = data.get(detail, []) - return detail - except Exception as e: - # If anything goes wrong we cannot reply details - print(e) - return [] - finally: - await client.aclose() + try: + async with client.get(f"{endpoint}{route}", headers=headers) as resp: + await _ensure_success(resp) + data = await resp.json() + detail = data.get(detail, []) + return detail + except Exception as e: + # If anything goes wrong we cannot reply details + print(e) + return [] def ep2base(ep): if "/v1" in ep: @@ -257,6 +258,71 @@ async def decrement_usage(endpoint: str, model: str) -> None: # usage_counts.pop(endpoint, None) await publish_snapshot() +def iso8601_ns(): + ns_since_epoch = time.time_ns() + dt = datetime.datetime.fromtimestamp( + ns_since_epoch / 1_000_000_000, # seconds + tz=datetime.timezone.utc + ) + iso8601_with_ns = ( + dt.strftime("%Y-%m-%dT%H:%M:%S.") + f"{ns_since_epoch % 1_000_000_000:09d}Z" + ) + return iso8601_with_ns + +class rechunk: + def openai_chat_completion2ollama(chunk: dict, stream: bool, start_ts: float): + rechunk = { "model": chunk.model, + "created_at": iso8601_ns() , + "done_reason": chunk.choices[0].finish_reason, + "load_duration": None, + "prompt_eval_count": None, + "prompt_eval_duration": None, + "eval_count": None, + "eval_duration": None, + "eval_count": (chunk.usage.completion_tokens if chunk.usage is not None else None), + "prompt_eval_count": (chunk.usage.prompt_tokens if chunk.usage is not None else None), + "eval_duration": (int((time.perf_counter() - start_ts) * 1000) if chunk.usage is not None else None), + "response_token/s": (round(chunk.usage.total_tokens / (time.perf_counter() - start_ts), 2) if chunk.usage is not None else None) + } + if stream == True: + rechunk["message"] = {"role": chunk.choices[0].delta.role or "assistant", "content": chunk.choices[0].delta.content, "thinking": None, "images": None, "tool_name": None, "tool_calls": None} + else: + rechunk["message"] = {"role": chunk.choices[0].message.role or "assistant", "content": chunk.choices[0].message.content, "thinking": None, "images": None, "tool_name": None, "tool_calls": None} + return rechunk + + def openai_completion2ollama(chunk: dict, stream: bool, start_ts: float): + with_thinking = chunk.choices[0] if chunk.choices[0] else None + thinking = getattr(with_thinking, "reasoning", None) if with_thinking else None + rechunk = { "model": chunk.model, + "created_at": iso8601_ns(), + "load_duration": None, + "done_reason": chunk.choices[0].finish_reason, + "total_duration": None, + "eval_duration": (int((time.perf_counter() - start_ts) * 1000) if chunk.usage is not None else None), + "thinking": thinking, + "context": None, + "response": chunk.choices[0].text + } + return rechunk + + def openai_embeddings2ollama(chunk: dict): + rechunk = {"embedding": chunk.data[0].embedding} + return rechunk + + def openai_embed2ollama(chunk: dict, model: str): + rechunk = { "model": model, + "created_at": iso8601_ns(), + "done": None, + "done_reason": None, + "total_duration": None, + "load_duration": None, + "prompt_eval_count": None, + "prompt_eval_duration": None, + "eval_count": None, + "eval_duration": None, + "embeddings": [chunk.data[0].embedding] + } + return rechunk # ------------------------------------------------------------------ # SSE Helpser # ------------------------------------------------------------------ @@ -269,6 +335,11 @@ async def publish_snapshot(): continue await q.put(snapshot) +async def close_all_sse_queues(): + for q in list(_subscribers): + # sentinel value that the generator will recognise + await q.put(None) + # ------------------------------------------------------------------ # Subscriber helpers # ------------------------------------------------------------------ @@ -305,16 +376,17 @@ async def choose_endpoint(model: str) -> str: 1️⃣ Query every endpoint for its advertised models (`/api/tags`). 2️⃣ Build a list of endpoints that contain the requested model. 3️⃣ For those endpoints, find those that have the model loaded - (`/api/ps`) *and* still have a free slot. + (`/api/ps`) *and* still have a free slot. 4️⃣ If none are both loaded and free, fall back to any endpoint - from the filtered list that simply has a free slot. + from the filtered list that simply has a free slot and randomly + select one. 5️⃣ If all are saturated, pick any endpoint from the filtered list - (the request will queue on that endpoint). + (the request will queue on that endpoint). 6️⃣ If no endpoint advertises the model at all, raise an error. """ # 1️⃣ Gather advertised‑model sets for all endpoints concurrently - tag_tasks = [fetch_available_models(ep) for ep in config.endpoints if "/v1" not in ep] - tag_tasks += [fetch_available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep] + tag_tasks = [fetch.available_models(ep) for ep in config.endpoints if "/v1" not in ep] + tag_tasks += [fetch.available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep] advertised_sets = await asyncio.gather(*tag_tasks) # 2️⃣ Filter endpoints that advertise the requested model @@ -325,16 +397,24 @@ async def choose_endpoint(model: str) -> str: # 6️⃣ if not candidate_endpoints: - raise RuntimeError( - f"None of the configured endpoints ({', '.join(config.endpoints)}) " - f"advertise the model '{model}'." - ) + if ":latest" in model: #ollama naming convention not applicable to openai + model = model.split(":latest") + model = model[0] + candidate_endpoints = [ + ep for ep, models in zip(config.endpoints, advertised_sets) + if model in models + ] + if not candidate_endpoints: + raise RuntimeError( + f"None of the configured endpoints ({', '.join(config.endpoints)}) " + f"advertise the model '{model}'." + ) # 3️⃣ Among the candidates, find those that have the model *loaded* # (concurrently, but only for the filtered list) - load_tasks = [fetch_loaded_models(ep) for ep in candidate_endpoints] + load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints] loaded_sets = await asyncio.gather(*load_tasks) - + async with usage_lock: # Helper: get current usage count for (endpoint, model) def current_usage(ep: str) -> int: @@ -343,7 +423,7 @@ async def choose_endpoint(model: str) -> str: # 3️⃣ Endpoints that have the model loaded *and* a free slot loaded_and_free = [ ep for ep, models in zip(candidate_endpoints, loaded_sets) - if model in models and usage_counts[ep].get(model, 0) < config.max_concurrent_connections + if model in models and usage_counts.get(ep, {}).get(model, 0) < config.max_concurrent_connections ] if loaded_and_free: @@ -353,12 +433,11 @@ async def choose_endpoint(model: str) -> str: # 4️⃣ Endpoints among the candidates that simply have a free slot endpoints_with_free_slot = [ ep for ep in candidate_endpoints - if usage_counts[ep].get(model, 0) < config.max_concurrent_connections + if usage_counts.get(ep, {}).get(model, 0) < config.max_concurrent_connections ] if endpoints_with_free_slot: - ep = min(endpoints_with_free_slot, key=current_usage) - return ep + return random.choice(endpoints_with_free_slot) # 5️⃣ All candidate endpoints are saturated – pick one with lowest usages count (will queue) ep = min(candidate_endpoints, key=current_usage) @@ -372,7 +451,6 @@ async def proxy(request: Request): """ Proxy a generate request to Ollama and stream the response back to the client. """ - # 1. Parse and validate request try: body_bytes = await request.body() payload = json.loads(body_bytes.decode("utf-8")) @@ -386,7 +464,7 @@ async def proxy(request: Request): stream = payload.get("stream") think = payload.get("think") raw = payload.get("raw") - format = payload.get("format") + _format = payload.get("format") images = payload.get("images") options = payload.get("options") keep_alive = payload.get("keep_alive") @@ -402,29 +480,53 @@ async def proxy(request: Request): except json.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e - # 2. Decide which endpoint to use + endpoint = await choose_endpoint(model) + is_openai_endpoint = "/v1" in endpoint + if is_openai_endpoint: + if ":latest" in model: + model = model.split(":latest") + model = model[0] + params = { + "prompt": prompt, + "model": model, + } - # Increment usage counter for this endpoint‑model pair + optional_params = { + "stream": stream, + } + + params.update({k: v for k, v in optional_params.items() if v is not None}) + oclient = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint]) + else: + client = ollama.AsyncClient(host=endpoint) await increment_usage(endpoint, model) - # 3. Create Ollama client instance - client = ollama.AsyncClient(host=endpoint) - # 4. Async generator that streams data and decrements the counter async def stream_generate_response(): try: - async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=format, images=images, options=options, keep_alive=keep_alive) + if is_openai_endpoint: + start_ts = time.perf_counter() + async_gen = await oclient.completions.create(**params) + else: + async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive) if stream == True: async for chunk in async_gen: + if is_openai_endpoint: + chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts) if hasattr(chunk, "model_dump_json"): json_line = chunk.model_dump_json() else: json_line = json.dumps(chunk) yield json_line.encode("utf-8") + b"\n" else: + if is_openai_endpoint: + response = rechunk.openai_completion2ollama(async_gen, stream, start_ts) + response = json.dumps(response) + else: + response = async_gen.model_dump_json() json_line = ( - async_gen.model_dump_json() + response if hasattr(async_gen, "model_dump_json") else json.dumps(async_gen) ) @@ -468,23 +570,46 @@ async def chat_proxy(request: Request): ) if not isinstance(messages, list): raise HTTPException( - status_code=400, detail="Missing or invalid 'message' field (must be a list)" + status_code=400, detail="Missing or invalid 'messages' field (must be a list)" ) except json.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Endpoint logic endpoint = await choose_endpoint(model) - await increment_usage(endpoint, model) - client = ollama.AsyncClient(host=endpoint) + is_openai_endpoint = "/v1" in endpoint + if is_openai_endpoint: + if ":latest" in model: + model = model.split(":latest") + model = model[0] + params = { + "messages": messages, + "model": model, + } + optional_params = { + "tools": tools, + "stream": stream, + } + + params.update({k: v for k, v in optional_params.items() if v is not None}) + oclient = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint]) + else: + client = ollama.AsyncClient(host=endpoint) + await increment_usage(endpoint, model) # 3. Async generator that streams chat data and decrements the counter async def stream_chat_response(): try: # The chat method returns a generator of dicts (or GenerateResponse) - async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive) + if is_openai_endpoint: + start_ts = time.perf_counter() + async_gen = await oclient.chat.completions.create(**params) + else: + async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive) if stream == True: async for chunk in async_gen: + if is_openai_endpoint: + chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts) # `chunk` can be a dict or a pydantic model – dump to JSON safely if hasattr(chunk, "model_dump_json"): json_line = chunk.model_dump_json() @@ -492,8 +617,13 @@ async def chat_proxy(request: Request): json_line = json.dumps(chunk) yield json_line.encode("utf-8") + b"\n" else: + if is_openai_endpoint: + response = rechunk.openai_chat_completion2ollama(async_gen, stream, start_ts) + response = json.dumps(response) + else: + response = async_gen.model_dump_json() json_line = ( - async_gen.model_dump_json() + response if hasattr(async_gen, "model_dump_json") else json.dumps(async_gen) ) @@ -541,14 +671,24 @@ async def embedding_proxy(request: Request): # 2. Endpoint logic endpoint = await choose_endpoint(model) + is_openai_endpoint = "/v1" in endpoint + if is_openai_endpoint: + if ":latest" in model: + model = model.split(":latest") + model = model[0] + client = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint]) + else: + client = ollama.AsyncClient(host=endpoint) await increment_usage(endpoint, model) - client = ollama.AsyncClient(host=endpoint) - # 3. Async generator that streams embedding data and decrements the counter async def stream_embedding_response(): try: # The chat method returns a generator of dicts (or GenerateResponse) - async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive) + if is_openai_endpoint: + async_gen = await client.embeddings.create(input=[prompt], model=model) + async_gen = rechunk.openai_embeddings2ollama(async_gen) + else: + async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive) if hasattr(async_gen, "model_dump_json"): json_line = async_gen.model_dump_json() else: @@ -579,7 +719,7 @@ async def embed_proxy(request: Request): payload = json.loads(body_bytes.decode("utf-8")) model = payload.get("model") - input = payload.get("input") + _input = payload.get("input") truncate = payload.get("truncate") options = payload.get("options") keep_alive = payload.get("keep_alive") @@ -588,7 +728,7 @@ async def embed_proxy(request: Request): raise HTTPException( status_code=400, detail="Missing required field 'model'" ) - if not input: + if not _input: raise HTTPException( status_code=400, detail="Missing required field 'input'" ) @@ -597,14 +737,24 @@ async def embed_proxy(request: Request): # 2. Endpoint logic endpoint = await choose_endpoint(model) + is_openai_endpoint = "/v1" in endpoint + if is_openai_endpoint: + if ":latest" in model: + model = model.split(":latest") + model = model[0] + client = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint]) + else: + client = ollama.AsyncClient(host=endpoint) await increment_usage(endpoint, model) - client = ollama.AsyncClient(host=endpoint) - # 3. Async generator that streams embed data and decrements the counter async def stream_embedding_response(): try: # The chat method returns a generator of dicts (or GenerateResponse) - async_gen = await client.embed(model=model, input=input, truncate=truncate, options=options, keep_alive=keep_alive) + if is_openai_endpoint: + async_gen = await client.embeddings.create(input=[_input], model=model) + async_gen = rechunk.openai_embed2ollama(async_gen, model) + else: + async_gen = await client.embed(model=model, input=_input, truncate=truncate, options=options, keep_alive=keep_alive) if hasattr(async_gen, "model_dump_json"): json_line = async_gen.model_dump_json() else: @@ -877,7 +1027,7 @@ async def version_proxy(request: Request): """ # 1. Query all endpoints for version - tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep] + tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep] all_versions = await asyncio.gather(*tasks) def version_key(v): @@ -900,12 +1050,21 @@ async def tags_proxy(request: Request): """ # 1. Query all endpoints for models - tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep] - tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep] + tasks = [fetch.endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep] + tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep] all_models = await asyncio.gather(*tasks) models = {'models': []} for modellist in all_models: + for model in modellist: + if not "model" in model.keys(): # Relable OpenAI models with Ollama Model.model from Model.id + model['model'] = model['id'] + ":latest" + else: + model['id'] = model['model'] + if not "name" in model.keys(): # Relable OpenAI models with Ollama Model.name from Model.model to have model,name keys + model['name'] = model['model'] + else: + model['id'] = model['model'] models['models'] += modellist # 2. Return a JSONResponse with a deduplicated list of unique models for inference @@ -924,7 +1083,7 @@ async def ps_proxy(request: Request): """ # 1. Query all endpoints for running models - tasks = [fetch_endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints if "/v1" not in ep] + tasks = [fetch.endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints if "/v1" not in ep] loaded_models = await asyncio.gather(*tasks) models = {'models': []} @@ -960,22 +1119,22 @@ async def config_proxy(request: Request): """ async def check_endpoint(url: str): try: - async with httpx.AsyncClient(timeout=1, transport=AiohttpTransport()) as client: - if "/v1" in url: - headers = {"Authorization": "Bearer " + config.api_keys[url]} - r = await client.get(f"{url}/models", headers=headers) - else: - r = await client.get(f"{url}/api/version") - r.raise_for_status() - data = r.json() - if "/v1" in url: - return {"url": url, "status": "ok", "version": "latest"} - else: - return {"url": url, "status": "ok", "version": data.get("version")} - except Exception as exc: - return {"url": url, "status": "error", "detail": str(exc)} - finally: - await client.aclose() + client: aiohttp.ClientSession = app_state["session"] + if "/v1" in url: + headers = {"Authorization": "Bearer " + config.api_keys[url]} + async with client.get(f"{url}/models", headers=headers) as resp: + await _ensure_success(resp) + data = await resp.json() + else: + async with client.get(f"{url}/api/version") as resp: + await _ensure_success(resp) + data = await resp.json() + if "/v1" in url: + return {"url": url, "status": "ok", "version": "latest"} + else: + return {"url": url, "status": "ok", "version": data.get("version")} + except Exception as e: + return {"url": url, "status": "error", "detail": str(e)} results = await asyncio.gather(*[check_endpoint(ep) for ep in config.endpoints]) return {"endpoints": results} @@ -995,14 +1154,14 @@ async def openai_embedding_proxy(request: Request): payload = json.loads(body_bytes.decode("utf-8")) model = payload.get("model") - input = payload.get("input") + doc = payload.get("input") if not model: raise HTTPException( status_code=400, detail="Missing required field 'model'" ) - if not input: + if not doc: raise HTTPException( status_code=400, detail="Missing required field 'input'" ) @@ -1016,10 +1175,11 @@ async def openai_embedding_proxy(request: Request): api_key = config.api_keys[endpoint] else: api_key = "ollama" - oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key=api_key) + base_url = ep2base(endpoint) + oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key) # 3. Async generator that streams embedding data and decrements the counter - async_gen = await oclient.embeddings.create(input=[input], model=model) + async_gen = await oclient.embeddings.create(input=[doc], model=model) await decrement_usage(endpoint, model) @@ -1055,33 +1215,31 @@ async def openai_chat_completions_proxy(request: Request): max_completion_tokens = payload.get("max_completion_tokens") tools = payload.get("tools") + if ":latest" in model: + model = model.split(":latest") + model = model[0] + params = { "messages": messages, "model": model, + } + + optional_params = { + "tools": tools, + "response_format": response_format, + "stream_options": stream_options, + "max_completion_tokens": max_completion_tokens, + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + "seed": seed, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, "stop": stop, "stream": stream, } - if tools is not None: - params["tools"] = tools - if response_format is not None: - params["response_format"] = response_format - if stream_options is not None: - params["stream_options"] = stream_options - if max_completion_tokens is not None: - params["max_completion_tokens"] = max_completion_tokens - if max_tokens is not None: - params["max_tokens"] = max_tokens - if temperature is not None: - params["temperature"] = temperature - if top_p is not None: - params["top_p"] = top_p - if seed is not None: - params["seed"] = seed - if presence_penalty is not None: - params["presence_penalty"] = presence_penalty - if frequency_penalty is not None: - params["frequency_penalty"] = frequency_penalty + params.update({k: v for k, v in optional_params.items() if v is not None}) if not model: raise HTTPException( @@ -1161,23 +1319,31 @@ async def openai_completions_proxy(request: Request): max_completion_tokens = payload.get("max_completion_tokens") suffix = payload.get("suffix") + if ":latest" in model: + model = model.split(":latest") + model = model[0] + params = { "prompt": prompt, "model": model, - "frequency_penalty": frequency_penalty, - "presence_penalty": presence_penalty, - "seed": seed, + } + + optional_params = { + "frequency_penalty": frequency_penalty, + "presence_penalty": presence_penalty, + "seed": seed, "stop": stop, "stream": stream, + "stream_options": stream_options, "temperature": temperature, - "top_p": top_p, + "top_p": top_p, "max_tokens": max_tokens, + "max_completion_tokens": max_completion_tokens, "suffix": suffix } - if stream_options is not None: - params["stream_options"] = stream_options - + params.update({k: v for k, v in optional_params.items() if v is not None}) + if not model: raise HTTPException( status_code=400, detail="Missing required field 'model'" @@ -1238,8 +1404,8 @@ async def openai_models_proxy(request: Request): """ # 1. Query all endpoints for models - tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep] - tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep] + tasks = [fetch.endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep] + tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep] all_models = await asyncio.gather(*tasks) models = {'data': []} @@ -1289,9 +1455,7 @@ async def health_proxy(request: Request): * The HTTP status code is 200 when everything is healthy, 503 otherwise. """ # Run all health checks in parallel - tasks = [ - fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints - ] + tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints] results = await asyncio.gather(*tasks, return_exceptions=True) @@ -1333,6 +1497,8 @@ async def usage_stream(request: Request): if await request.is_disconnected(): break data = await queue.get() + if data is None: + break # Send the data as a single SSE message yield f"data: {data}\n\n" finally: @@ -1342,7 +1508,7 @@ async def usage_stream(request: Request): return StreamingResponse(event_generator(), media_type="text/event-stream") # ------------------------------------------------------------- -# 28. FastAPI startup event – load configuration +# 28. FastAPI startup/shutdown events # ------------------------------------------------------------- @app.on_event("startup") async def startup_event() -> None: @@ -1350,4 +1516,18 @@ async def startup_event() -> None: # Load YAML config (or use defaults if not present) config = Config.from_yaml(Path("config.yaml")) print(f"Loaded configuration:\n endpoints={config.endpoints},\n " - f"max_concurrent_connections={config.max_concurrent_connections}") \ No newline at end of file + f"max_concurrent_connections={config.max_concurrent_connections}") + + ssl_context = ssl.create_default_context() + connector = aiohttp.TCPConnector(limit=0, limit_per_host=512, ssl=ssl_context) + timeout = aiohttp.ClientTimeout(total=5, connect=5, sock_read=120, sock_connect=5) + session = aiohttp.ClientSession(connector=connector, timeout=timeout) + + app_state["connector"] = connector + app_state["session"] = session + + +@app.on_event("shutdown") +async def shutdown_event() -> None: + await close_all_sse_queues() + await app_state["session"].close() \ No newline at end of file diff --git a/static/index.html b/static/index.html index 01d902a..0bd5399 100644 --- a/static/index.html +++ b/static/index.html @@ -21,9 +21,9 @@ top: 1rem; /* distance from top edge */ right: 1rem; /* distance from right edge */ cursor: pointer; - min-width: 2.5rem; - min-height: 2.5rem; - font-size: 1.5rem; + min-width: 1rem; + min-height: 1rem; + font-size: 1rem; } .tables-wrapper { display: flex;