commit caaf26f0fc
6 changed files with 414 additions and 218 deletions
.gitignore (vendored, 3 changes)

@@ -63,3 +63,6 @@ cython_debug/
 # Logfile(s)
 *.log
 *.sqlite3
+
+# Config
+config.yaml
README.md (17 changes)

@@ -18,9 +18,19 @@ endpoints:
 - http://ollama0:11434
 - http://ollama1:11434
 - http://ollama2:11434
+- https://api.openai.com/v1

 # Maximum concurrent connections *per endpoint‑model pair*
 max_concurrent_connections: 2
+
+# API keys for remote endpoints
+# Set an environment variable like OPENAI_KEY
+# Confirm endpoints are exactly as in endpoints block
+api_keys:
+  "http://192.168.0.50:11434": "ollama"
+  "http://192.168.0.51:11434": "ollama"
+  "http://192.168.0.52:11434": "ollama"
+  "https://api.openai.com/v1": "${OPENAI_KEY}"
 ```

 Run the NOMYO Router in a dedicated virtual environment, install the requirements and run with uvicorn:

@@ -30,6 +40,13 @@ python3 -m venv .venv/router
 source .venv/router/bin/activate
 pip3 install -r requirements.txt
 ```
+
+on the shell do:
+
+```
+export OPENAI_KEY=YOUR_SECRET_API_KEY
+```
+
 finally you can

 ```
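The last fenced block of the README hunk is cut off before the run command. Given `app = FastAPI()` in `router.py`, the usual invocation would be along the lines of `uvicorn router:app --host 0.0.0.0 --port 8000` (host and port here are assumptions for illustration, not taken from the README).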
config.yaml

@@ -3,8 +3,7 @@ endpoints:
 - http://192.168.0.50:11434
 - http://192.168.0.51:11434
 - http://192.168.0.52:11434
-#- https://openrouter.ai/api/v1
-#- https://api.openai.com/v1
+- https://api.openai.com/v1

 # Maximum concurrent connections *per endpoint‑model pair* (equals to OLLAMA_NUM_PARALLEL)
 max_concurrent_connections: 2

@@ -16,5 +15,4 @@ api_keys:
 "http://192.168.0.50:11434": "ollama"
 "http://192.168.0.51:11434": "ollama"
 "http://192.168.0.52:11434": "ollama"
-#"https://openrouter.ai/api/v1": "${OPENROUTER_KEY}"
-#"https://api.openai.com/v1": "${OPENAI_KEY}"
+"https://api.openai.com/v1": "${OPENAI_KEY}"
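Both the README example and this file write remote keys as `"${OPENAI_KEY}"`, which only works if the config loader expands environment variables when reading the YAML. A minimal sketch of such expansion; the `load_config` helper and the use of `os.path.expandvars` are assumptions for illustration, not necessarily what `Config.from_yaml` does:

```python
import os
import yaml

def load_config(path: str) -> dict:
    # Read the raw YAML and expand ${VAR} / $VAR references against the
    # process environment before parsing. Unknown variables are left as-is.
    with open(path) as fh:
        raw = fh.read()
    return yaml.safe_load(os.path.expandvars(raw))

# `export OPENAI_KEY=...` must happen before the process starts, otherwise
# the literal string "${OPENAI_KEY}" survives into the parsed config.
```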
requirements.txt

@@ -14,8 +14,6 @@ fastapi-sse==1.1.1
 frozenlist==1.7.0
 h11==0.16.0
 httpcore==1.0.9
-httpx==0.28.1
-httpx-aiohttp==0.1.8
 idna==3.10
 jiter==0.10.0
 multidict==6.6.4
router.py (598 changes)
@@ -2,17 +2,17 @@
 title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing
 author: alpha-nerd-nomyo
 author_url: https://github.com/nomyo-ai
-version: 0.2.2
+version: 0.3
 license: AGPL
 """
 # -------------------------------------------------------------
-import json, time, asyncio, yaml, httpx, ollama, openai, os, re
-from httpx_aiohttp import AiohttpTransport
+import json, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random
 from pathlib import Path
 from typing import Dict, Set, List, Optional
 from fastapi import FastAPI, Request, HTTPException
 from fastapi_sse import sse_handler
 from fastapi.staticfiles import StaticFiles
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLResponse, RedirectResponse
 from pydantic import Field
 from pydantic_settings import BaseSettings
@@ -33,6 +33,14 @@ _error_cache: dict[str, float] = {}
 _subscribers: Set[asyncio.Queue] = set()
 _subscribers_lock = asyncio.Lock()
+
+# ------------------------------------------------------------------
+# aiohttp Global Sessions
+# ------------------------------------------------------------------
+app_state = {
+    "session": None,
+    "connector": None,
+}

 # -------------------------------------------------------------
 # 1. Configuration loader
 # -------------------------------------------------------------
@@ -85,7 +93,13 @@ config = Config()
 # -------------------------------------------------------------
 app = FastAPI()
 sse_handler.app = app

+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["GET", "POST", "DELETE"],
+    allow_headers=["Authorization", "Content-Type"],
+)
 # -------------------------------------------------------------
 # 3. Global state: per‑endpoint per‑model active connection counters
 # -------------------------------------------------------------
@@ -95,128 +109,115 @@ usage_lock = asyncio.Lock() # protects access to usage_counts
 # -------------------------------------------------------------
 # 4. Helperfunctions
 # -------------------------------------------------------------
+aiotimeout = aiohttp.ClientTimeout(total=5)

 def _is_fresh(cached_at: float, ttl: int) -> bool:
     return (time.time() - cached_at) < ttl

-def get_httpx_client(endpoint: str) -> httpx.AsyncClient:
-    """
-    Use persistent connections to request endpoint info for reliable results
-    in high load situations or saturated endpoints.
-    """
-    return httpx.AsyncClient(
-        base_url=endpoint,
-        timeout=httpx.Timeout(5.0, read=5.0, write=None, connect=5.0),
-        #limits=httpx.Limits(
-        #    max_keepalive_connections=64,
-        #    max_connections=64
-        #),
-        transport=AiohttpTransport()
-    )
+async def _ensure_success(resp: aiohttp.ClientResponse) -> None:
+    if resp.status >= 400:
+        text = await resp.text()
+        raise HTTPException(status_code=resp.status, detail=text)

-async def fetch_available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
-    """
-    Query <endpoint>/api/tags and return a set of all model names that the
-    endpoint *advertises* (i.e. is capable of serving). This endpoint lists
-    every model that is installed on the Ollama instance, regardless of
-    whether the model is currently loaded into memory.
-
-    If the request fails (e.g. timeout, 5xx, or malformed response), an empty
-    set is returned.
-    """
-    headers = None
-    if api_key is not None:
-        headers = {"Authorization": "Bearer " + api_key}
-
-    if endpoint in _models_cache:
-        models, cached_at = _models_cache[endpoint]
-        if _is_fresh(cached_at, 300):
-            return models
-        else:
-            # stale entry – drop it
-            del _models_cache[endpoint]
-
-    if endpoint in _error_cache:
-        if _is_fresh(_error_cache[endpoint], 1):
-            # Still within the short error TTL – pretend nothing is available
-            return set()
-        else:
-            # Error expired – remove it
-            del _error_cache[endpoint]
-
-    try:
-        client = get_httpx_client(endpoint)
-        if "/v1" in endpoint:
-            resp = await client.get(f"/models", headers=headers)
-        else:
-            resp = await client.get(f"/api/tags")
-        resp.raise_for_status()
-        data = resp.json()
-        # Expected format:
-        # {"models": [{"name": "model1"}, {"name": "model2"}]}
-        if "/v1" in endpoint:
-            models = {m.get("id") for m in data.get("data", []) if m.get("id")}
-        else:
-            models = {m.get("name") for m in data.get("models", []) if m.get("name")}
-
-        if models:
-            _models_cache[endpoint] = (models, time.time())
-            return models
-        else:
-            # Empty list – treat as “no models”, but still cache for 300s
-            _models_cache[endpoint] = (models, time.time())
-            return models
-    except Exception as e:
-        # Treat any error as if the endpoint offers no models
-        print(f"[fetch_available_models] {endpoint} error: {e}")
-        _error_cache[endpoint] = time.time()
-        return set()
-    finally:
-        await client.aclose()
+class fetch:
+    async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
+        """
+        Query <endpoint>/api/tags and return a set of all model names that the
+        endpoint *advertises* (i.e. is capable of serving). This endpoint lists
+        every model that is installed on the Ollama instance, regardless of
+        whether the model is currently loaded into memory.
+
+        If the request fails (e.g. timeout, 5xx, or malformed response), an empty
+        set is returned.
+        """
+        headers = None
+        if api_key is not None:
+            headers = {"Authorization": "Bearer " + api_key}
+
+        if endpoint in _models_cache:
+            models, cached_at = _models_cache[endpoint]
+            if _is_fresh(cached_at, 300):
+                return models
+            else:
+                # stale entry – drop it
+                del _models_cache[endpoint]
+
+        if endpoint in _error_cache:
+            if _is_fresh(_error_cache[endpoint], 1):
+                # Still within the short error TTL – pretend nothing is available
+                return set()
+            else:
+                # Error expired – remove it
+                del _error_cache[endpoint]
+
+        if "/v1" in endpoint:
+            endpoint_url = f"{endpoint}/models"
+            key = "data"
+        else:
+            endpoint_url = f"{endpoint}/api/tags"
+            key = "models"
+        client: aiohttp.ClientSession = app_state["session"]
+        try:
+            async with client.get(endpoint_url, headers=headers) as resp:
+                await _ensure_success(resp)
+                data = await resp.json()
+
+            items = data.get(key, [])
+            models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
+
+            if models:
+                _models_cache[endpoint] = (models, time.time())
+                return models
+            else:
+                # Empty list – treat as “no models”, but still cache for 300s
+                _models_cache[endpoint] = (models, time.time())
+                return models
+        except Exception as e:
+            # Treat any error as if the endpoint offers no models
+            print(f"[fetch.available_models] {endpoint} error: {e}")
+            _error_cache[endpoint] = time.time()
+            return set()

-async def fetch_loaded_models(endpoint: str) -> Set[str]:
-    """
-    Query <endpoint>/api/ps and return a set of model names that are currently
-    loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
-    set is returned.
-    """
-    try:
-        client = get_httpx_client(endpoint)
-        resp = await client.get(f"/api/ps")
-        resp.raise_for_status()
-        data = resp.json()
-        # The response format is:
-        # {"models": [{"name": "model1"}, {"name": "model2"}]}
-        models = {m.get("name") for m in data.get("models", []) if m.get("name")}
-        return models
-    except Exception:
-        # If anything goes wrong we simply assume the endpoint has no models
-        return set()
-    finally:
-        await client.aclose()
+    async def loaded_models(endpoint: str) -> Set[str]:
+        """
+        Query <endpoint>/api/ps and return a set of model names that are currently
+        loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
+        set is returned.
+        """
+        client: aiohttp.ClientSession = app_state["session"]
+        try:
+            async with client.get(f"{endpoint}/api/ps") as resp:
+                await _ensure_success(resp)
+                data = await resp.json()
+            # The response format is:
+            # {"models": [{"name": "model1"}, {"name": "model2"}]}
+            models = {m.get("name") for m in data.get("models", []) if m.get("name")}
+            return models
+        except Exception:
+            # If anything goes wrong we simply assume the endpoint has no models
+            return set()

-async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
-    """
-    Query <endpoint>/<route> to fetch <detail> and return a List of dicts with details
-    for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail.
-    """
-    headers = None
-    if api_key is not None:
-        headers = {"Authorization": "Bearer " + api_key}
-
-    try:
-        client = get_httpx_client(endpoint)
-        resp = await client.get(f"{route}", headers=headers)
-        resp.raise_for_status()
-        data = resp.json()
-        detail = data.get(detail, [])
-        return detail
-    except Exception as e:
-        # If anything goes wrong we cannot reply details
-        print(e)
-        return []
-    finally:
-        await client.aclose()
+    async def endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
+        """
+        Query <endpoint>/<route> to fetch <detail> and return a List of dicts with details
+        for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail.
+        """
+        client: aiohttp.ClientSession = app_state["session"]
+        headers = None
+        if api_key is not None:
+            headers = {"Authorization": "Bearer " + api_key}
+
+        try:
+            async with client.get(f"{endpoint}{route}", headers=headers) as resp:
+                await _ensure_success(resp)
+                data = await resp.json()
+            detail = data.get(detail, [])
+            return detail
+        except Exception as e:
+            # If anything goes wrong we cannot reply details
+            print(e)
+            return []

 def ep2base(ep):
     if "/v1" in ep:
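A note on the new `class fetch` above: its methods are declared without `self` and are only ever called on the class itself (`fetch.available_models(...)`), so the class serves purely as a namespace. In Python 3 a function looked up on a class is an ordinary function, so this works; a minimal sketch:

```python
import asyncio

class ns:
    async def hello(name: str) -> str:  # no self – the class is just a namespace
        return f"hi {name}"

print(asyncio.run(ns.hello("ollama")))  # -> hi ollama
```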
@@ -257,6 +258,71 @@ async def decrement_usage(endpoint: str, model: str) -> None:
     # usage_counts.pop(endpoint, None)
     await publish_snapshot()

+def iso8601_ns():
+    ns_since_epoch = time.time_ns()
+    dt = datetime.datetime.fromtimestamp(
+        ns_since_epoch / 1_000_000_000,  # seconds
+        tz=datetime.timezone.utc
+    )
+    iso8601_with_ns = (
+        dt.strftime("%Y-%m-%dT%H:%M:%S.") + f"{ns_since_epoch % 1_000_000_000:09d}Z"
+    )
+    return iso8601_with_ns
+
+class rechunk:
+    def openai_chat_completion2ollama(chunk: dict, stream: bool, start_ts: float):
+        rechunk = { "model": chunk.model,
+                    "created_at": iso8601_ns(),
+                    "done_reason": chunk.choices[0].finish_reason,
+                    "load_duration": None,
+                    "prompt_eval_count": None,
+                    "prompt_eval_duration": None,
+                    "eval_count": None,
+                    "eval_duration": None,
+                    "eval_count": (chunk.usage.completion_tokens if chunk.usage is not None else None),
+                    "prompt_eval_count": (chunk.usage.prompt_tokens if chunk.usage is not None else None),
+                    "eval_duration": (int((time.perf_counter() - start_ts) * 1000) if chunk.usage is not None else None),
+                    "response_token/s": (round(chunk.usage.total_tokens / (time.perf_counter() - start_ts), 2) if chunk.usage is not None else None)
+                  }
+        if stream == True:
+            rechunk["message"] = {"role": chunk.choices[0].delta.role or "assistant", "content": chunk.choices[0].delta.content, "thinking": None, "images": None, "tool_name": None, "tool_calls": None}
+        else:
+            rechunk["message"] = {"role": chunk.choices[0].message.role or "assistant", "content": chunk.choices[0].message.content, "thinking": None, "images": None, "tool_name": None, "tool_calls": None}
+        return rechunk
+
+    def openai_completion2ollama(chunk: dict, stream: bool, start_ts: float):
+        with_thinking = chunk.choices[0] if chunk.choices[0] else None
+        thinking = getattr(with_thinking, "reasoning", None) if with_thinking else None
+        rechunk = { "model": chunk.model,
+                    "created_at": iso8601_ns(),
+                    "load_duration": None,
+                    "done_reason": chunk.choices[0].finish_reason,
+                    "total_duration": None,
+                    "eval_duration": (int((time.perf_counter() - start_ts) * 1000) if chunk.usage is not None else None),
+                    "thinking": thinking,
+                    "context": None,
+                    "response": chunk.choices[0].text
+                  }
+        return rechunk
+
+    def openai_embeddings2ollama(chunk: dict):
+        rechunk = {"embedding": chunk.data[0].embedding}
+        return rechunk
+
+    def openai_embed2ollama(chunk: dict, model: str):
+        rechunk = { "model": model,
+                    "created_at": iso8601_ns(),
+                    "done": None,
+                    "done_reason": None,
+                    "total_duration": None,
+                    "load_duration": None,
+                    "prompt_eval_count": None,
+                    "prompt_eval_duration": None,
+                    "eval_count": None,
+                    "eval_duration": None,
+                    "embeddings": [chunk.data[0].embedding]
+                  }
+        return rechunk
 # ------------------------------------------------------------------
 # SSE Helpser
 # ------------------------------------------------------------------
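One wrinkle in the added `openai_chat_completion2ollama`: the dict literal spells `eval_count`, `prompt_eval_count`, and `eval_duration` twice. That is legal Python, but only the last occurrence of a duplicated key survives, so the earlier `None` placeholders are dead entries:

```python
d = {"eval_count": None, "eval_count": 42}
print(d)  # {'eval_count': 42} – the later duplicate wins
```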
@@ -269,6 +335,11 @@ async def publish_snapshot():
             continue
         await q.put(snapshot)

+async def close_all_sse_queues():
+    for q in list(_subscribers):
+        # sentinel value that the generator will recognise
+        await q.put(None)
+
 # ------------------------------------------------------------------
 # Subscriber helpers
 # ------------------------------------------------------------------
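The `None` put here is consumed further down by the `usage_stream` generator (`if data is None: break`), the standard queue-sentinel shutdown pattern; in miniature:

```python
import asyncio

async def consumer(q: asyncio.Queue):
    while True:
        item = await q.get()
        if item is None:   # sentinel – producer wants us to stop
            break
        print("got", item)

async def main():
    q = asyncio.Queue()
    task = asyncio.create_task(consumer(q))
    await q.put("snapshot-1")
    await q.put(None)      # shut the consumer down cleanly
    await task

asyncio.run(main())
```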
@@ -305,16 +376,17 @@ async def choose_endpoint(model: str) -> str:
     1️⃣ Query every endpoint for its advertised models (`/api/tags`).
     2️⃣ Build a list of endpoints that contain the requested model.
     3️⃣ For those endpoints, find those that have the model loaded
        (`/api/ps`) *and* still have a free slot.
     4️⃣ If none are both loaded and free, fall back to any endpoint
-       from the filtered list that simply has a free slot.
+       from the filtered list that simply has a free slot and randomly
+       select one.
     5️⃣ If all are saturated, pick any endpoint from the filtered list
        (the request will queue on that endpoint).
     6️⃣ If no endpoint advertises the model at all, raise an error.
     """
     # 1️⃣ Gather advertised‑model sets for all endpoints concurrently
-    tag_tasks = [fetch_available_models(ep) for ep in config.endpoints if "/v1" not in ep]
-    tag_tasks += [fetch_available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
+    tag_tasks = [fetch.available_models(ep) for ep in config.endpoints if "/v1" not in ep]
+    tag_tasks += [fetch.available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
     advertised_sets = await asyncio.gather(*tag_tasks)

     # 2️⃣ Filter endpoints that advertise the requested model
@@ -325,16 +397,24 @@ async def choose_endpoint(model: str) -> str:

     # 6️⃣
-    if not candidate_endpoints:
-        raise RuntimeError(
-            f"None of the configured endpoints ({', '.join(config.endpoints)}) "
-            f"advertise the model '{model}'."
-        )
+    if ":latest" in model: #ollama naming convention not applicable to openai
+        model = model.split(":latest")
+        model = model[0]
+    candidate_endpoints = [
+        ep for ep, models in zip(config.endpoints, advertised_sets)
+        if model in models
+    ]
+    if not candidate_endpoints:
+        raise RuntimeError(
+            f"None of the configured endpoints ({', '.join(config.endpoints)}) "
+            f"advertise the model '{model}'."
+        )

     # 3️⃣ Among the candidates, find those that have the model *loaded*
     #     (concurrently, but only for the filtered list)
-    load_tasks = [fetch_loaded_models(ep) for ep in candidate_endpoints]
+    load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
     loaded_sets = await asyncio.gather(*load_tasks)

     async with usage_lock:
         # Helper: get current usage count for (endpoint, model)
         def current_usage(ep: str) -> int:
@@ -343,7 +423,7 @@ async def choose_endpoint(model: str) -> str:
         # 3️⃣ Endpoints that have the model loaded *and* a free slot
         loaded_and_free = [
             ep for ep, models in zip(candidate_endpoints, loaded_sets)
-            if model in models and usage_counts[ep].get(model, 0) < config.max_concurrent_connections
+            if model in models and usage_counts.get(ep, {}).get(model, 0) < config.max_concurrent_connections
         ]

         if loaded_and_free:
@@ -353,12 +433,11 @@ async def choose_endpoint(model: str) -> str:
         # 4️⃣ Endpoints among the candidates that simply have a free slot
         endpoints_with_free_slot = [
             ep for ep in candidate_endpoints
-            if usage_counts[ep].get(model, 0) < config.max_concurrent_connections
+            if usage_counts.get(ep, {}).get(model, 0) < config.max_concurrent_connections
         ]

         if endpoints_with_free_slot:
-            ep = min(endpoints_with_free_slot, key=current_usage)
-            return ep
+            return random.choice(endpoints_with_free_slot)

         # 5️⃣ All candidate endpoints are saturated – pick one with lowest usages count (will queue)
         ep = min(candidate_endpoints, key=current_usage)
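Taken together, the hunks above implement the docstring's tiers 3️⃣–5️⃣: prefer an endpoint that already has the model loaded and a free slot, then any candidate with a free slot (now chosen at random to spread load), then the least-used saturated endpoint. A self-contained sketch of the same policy; the names `pick`, `loaded`, `usage`, and `limit` are illustrative stand-ins, not the router's API:

```python
import random
from typing import Dict, List, Set

def pick(candidates: List[str], loaded: Dict[str, Set[str]],
         usage: Dict[str, Dict[str, int]], model: str, limit: int) -> str:
    count = lambda ep: usage.get(ep, {}).get(model, 0)
    free = lambda ep: count(ep) < limit
    # Tier 1: model already in memory and a slot is free -> least used
    hot = [ep for ep in candidates if model in loaded.get(ep, set()) and free(ep)]
    if hot:
        return min(hot, key=count)
    # Tier 2: any endpoint with a free slot -> random, to spread load
    cold = [ep for ep in candidates if free(ep)]
    if cold:
        return random.choice(cold)
    # Tier 3: everything saturated -> queue on the least-used endpoint
    return min(candidates, key=count)
```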
@@ -372,7 +451,6 @@ async def proxy(request: Request):
     """
     Proxy a generate request to Ollama and stream the response back to the client.
     """
     # 1. Parse and validate request
     try:
         body_bytes = await request.body()
         payload = json.loads(body_bytes.decode("utf-8"))
@@ -386,7 +464,7 @@ async def proxy(request: Request):
     stream = payload.get("stream")
     think = payload.get("think")
     raw = payload.get("raw")
-    format = payload.get("format")
+    _format = payload.get("format")
     images = payload.get("images")
     options = payload.get("options")
     keep_alive = payload.get("keep_alive")
@@ -402,29 +480,53 @@ async def proxy(request: Request):
     except json.JSONDecodeError as e:
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

     # 2. Decide which endpoint to use
     endpoint = await choose_endpoint(model)
+    is_openai_endpoint = "/v1" in endpoint
+    if is_openai_endpoint:
+        if ":latest" in model:
+            model = model.split(":latest")
+            model = model[0]
+        params = {
+            "prompt": prompt,
+            "model": model,
+        }

-    # Increment usage counter for this endpoint‑model pair
+        optional_params = {
+            "stream": stream,
+        }
+
+        params.update({k: v for k, v in optional_params.items() if v is not None})
+        oclient = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
+    else:
+        client = ollama.AsyncClient(host=endpoint)
     await increment_usage(endpoint, model)

-    # 3. Create Ollama client instance
-    client = ollama.AsyncClient(host=endpoint)
-
     # 4. Async generator that streams data and decrements the counter
     async def stream_generate_response():
         try:
-            async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=format, images=images, options=options, keep_alive=keep_alive)
+            if is_openai_endpoint:
+                start_ts = time.perf_counter()
+                async_gen = await oclient.completions.create(**params)
+            else:
+                async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive)
             if stream == True:
                 async for chunk in async_gen:
+                    if is_openai_endpoint:
+                        chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
                     if hasattr(chunk, "model_dump_json"):
                         json_line = chunk.model_dump_json()
                     else:
                         json_line = json.dumps(chunk)
                     yield json_line.encode("utf-8") + b"\n"
             else:
+                if is_openai_endpoint:
+                    response = rechunk.openai_completion2ollama(async_gen, stream, start_ts)
+                    response = json.dumps(response)
+                else:
+                    response = async_gen.model_dump_json()
                 json_line = (
-                    async_gen.model_dump_json()
+                    response
                     if hasattr(async_gen, "model_dump_json")
                     else json.dumps(async_gen)
                 )
@@ -468,23 +570,46 @@ async def chat_proxy(request: Request):
         )
     if not isinstance(messages, list):
         raise HTTPException(
-            status_code=400, detail="Missing or invalid 'message' field (must be a list)"
+            status_code=400, detail="Missing or invalid 'messages' field (must be a list)"
         )
     except json.JSONDecodeError as e:
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
-    await increment_usage(endpoint, model)
-    client = ollama.AsyncClient(host=endpoint)
+    is_openai_endpoint = "/v1" in endpoint
+    if is_openai_endpoint:
+        if ":latest" in model:
+            model = model.split(":latest")
+            model = model[0]
+        params = {
+            "messages": messages,
+            "model": model,
+        }
+
+        optional_params = {
+            "tools": tools,
+            "stream": stream,
+        }
+
+        params.update({k: v for k, v in optional_params.items() if v is not None})
+        oclient = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
+    else:
+        client = ollama.AsyncClient(host=endpoint)
+    await increment_usage(endpoint, model)
     # 3. Async generator that streams chat data and decrements the counter
     async def stream_chat_response():
         try:
             # The chat method returns a generator of dicts (or GenerateResponse)
-            async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
+            if is_openai_endpoint:
+                start_ts = time.perf_counter()
+                async_gen = await oclient.chat.completions.create(**params)
+            else:
+                async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
             if stream == True:
                 async for chunk in async_gen:
+                    if is_openai_endpoint:
+                        chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts)
                     # `chunk` can be a dict or a pydantic model – dump to JSON safely
                     if hasattr(chunk, "model_dump_json"):
                         json_line = chunk.model_dump_json()
@@ -492,8 +617,13 @@ async def chat_proxy(request: Request):
                         json_line = json.dumps(chunk)
                     yield json_line.encode("utf-8") + b"\n"
             else:
+                if is_openai_endpoint:
+                    response = rechunk.openai_chat_completion2ollama(async_gen, stream, start_ts)
+                    response = json.dumps(response)
+                else:
+                    response = async_gen.model_dump_json()
                 json_line = (
-                    async_gen.model_dump_json()
+                    response
                     if hasattr(async_gen, "model_dump_json")
                     else json.dumps(async_gen)
                 )
@@ -541,14 +671,24 @@ async def embedding_proxy(request: Request):

     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
+    is_openai_endpoint = "/v1" in endpoint
+    if is_openai_endpoint:
+        if ":latest" in model:
+            model = model.split(":latest")
+            model = model[0]
+        client = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
+    else:
+        client = ollama.AsyncClient(host=endpoint)
     await increment_usage(endpoint, model)
-    client = ollama.AsyncClient(host=endpoint)

     # 3. Async generator that streams embedding data and decrements the counter
     async def stream_embedding_response():
         try:
             # The chat method returns a generator of dicts (or GenerateResponse)
-            async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
+            if is_openai_endpoint:
+                async_gen = await client.embeddings.create(input=[prompt], model=model)
+                async_gen = rechunk.openai_embeddings2ollama(async_gen)
+            else:
+                async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
             if hasattr(async_gen, "model_dump_json"):
                 json_line = async_gen.model_dump_json()
             else:
@@ -579,7 +719,7 @@ async def embed_proxy(request: Request):
         payload = json.loads(body_bytes.decode("utf-8"))

     model = payload.get("model")
-    input = payload.get("input")
+    _input = payload.get("input")
     truncate = payload.get("truncate")
     options = payload.get("options")
     keep_alive = payload.get("keep_alive")
@@ -588,7 +728,7 @@ async def embed_proxy(request: Request):
         raise HTTPException(
             status_code=400, detail="Missing required field 'model'"
         )
-    if not input:
+    if not _input:
         raise HTTPException(
             status_code=400, detail="Missing required field 'input'"
         )
@@ -597,14 +737,24 @@ async def embed_proxy(request: Request):

     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
+    is_openai_endpoint = "/v1" in endpoint
+    if is_openai_endpoint:
+        if ":latest" in model:
+            model = model.split(":latest")
+            model = model[0]
+        client = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
+    else:
+        client = ollama.AsyncClient(host=endpoint)
     await increment_usage(endpoint, model)
-    client = ollama.AsyncClient(host=endpoint)

     # 3. Async generator that streams embed data and decrements the counter
     async def stream_embedding_response():
         try:
             # The chat method returns a generator of dicts (or GenerateResponse)
-            async_gen = await client.embed(model=model, input=input, truncate=truncate, options=options, keep_alive=keep_alive)
+            if is_openai_endpoint:
+                async_gen = await client.embeddings.create(input=[_input], model=model)
+                async_gen = rechunk.openai_embed2ollama(async_gen, model)
+            else:
+                async_gen = await client.embed(model=model, input=_input, truncate=truncate, options=options, keep_alive=keep_alive)
             if hasattr(async_gen, "model_dump_json"):
                 json_line = async_gen.model_dump_json()
             else:
@@ -877,7 +1027,7 @@ async def version_proxy(request: Request):

     """
     # 1. Query all endpoints for version
-    tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
+    tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
     all_versions = await asyncio.gather(*tasks)

     def version_key(v):
@@ -900,12 +1050,21 @@ async def tags_proxy(request: Request):
     """

     # 1. Query all endpoints for models
-    tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
-    tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
+    tasks = [fetch.endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
+    tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
     all_models = await asyncio.gather(*tasks)

     models = {'models': []}
     for modellist in all_models:
+        for model in modellist:
+            if not "model" in model.keys(): # Relable OpenAI models with Ollama Model.model from Model.id
+                model['model'] = model['id'] + ":latest"
+            else:
+                model['id'] = model['model']
+            if not "name" in model.keys(): # Relable OpenAI models with Ollama Model.name from Model.model to have model,name keys
+                model['name'] = model['model']
+            else:
+                model['id'] = model['model']
         models['models'] += modellist

     # 2. Return a JSONResponse with a deduplicated list of unique models for inference
@@ -924,7 +1083,7 @@ async def ps_proxy(request: Request):

     """
     # 1. Query all endpoints for running models
-    tasks = [fetch_endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints if "/v1" not in ep]
+    tasks = [fetch.endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints if "/v1" not in ep]
     loaded_models = await asyncio.gather(*tasks)

     models = {'models': []}
@@ -960,22 +1119,22 @@ async def config_proxy(request: Request):
     """
     async def check_endpoint(url: str):
         try:
-            async with httpx.AsyncClient(timeout=1, transport=AiohttpTransport()) as client:
-                if "/v1" in url:
-                    headers = {"Authorization": "Bearer " + config.api_keys[url]}
-                    r = await client.get(f"{url}/models", headers=headers)
-                else:
-                    r = await client.get(f"{url}/api/version")
-                r.raise_for_status()
-                data = r.json()
-                if "/v1" in url:
-                    return {"url": url, "status": "ok", "version": "latest"}
-                else:
-                    return {"url": url, "status": "ok", "version": data.get("version")}
-        except Exception as exc:
-            return {"url": url, "status": "error", "detail": str(exc)}
-        finally:
-            await client.aclose()
+            client: aiohttp.ClientSession = app_state["session"]
+            if "/v1" in url:
+                headers = {"Authorization": "Bearer " + config.api_keys[url]}
+                async with client.get(f"{url}/models", headers=headers) as resp:
+                    await _ensure_success(resp)
+                    data = await resp.json()
+            else:
+                async with client.get(f"{url}/api/version") as resp:
+                    await _ensure_success(resp)
+                    data = await resp.json()
+            if "/v1" in url:
+                return {"url": url, "status": "ok", "version": "latest"}
+            else:
+                return {"url": url, "status": "ok", "version": data.get("version")}
+        except Exception as e:
+            return {"url": url, "status": "error", "detail": str(e)}

     results = await asyncio.gather(*[check_endpoint(ep) for ep in config.endpoints])
     return {"endpoints": results}
@@ -995,14 +1154,14 @@ async def openai_embedding_proxy(request: Request):
         payload = json.loads(body_bytes.decode("utf-8"))

     model = payload.get("model")
-    input = payload.get("input")
+    doc = payload.get("input")


     if not model:
         raise HTTPException(
             status_code=400, detail="Missing required field 'model'"
         )
-    if not input:
+    if not doc:
         raise HTTPException(
             status_code=400, detail="Missing required field 'input'"
         )
@@ -1016,10 +1175,11 @@ async def openai_embedding_proxy(request: Request):
         api_key = config.api_keys[endpoint]
     else:
         api_key = "ollama"
-    oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key=api_key)
+    base_url = ep2base(endpoint)
+    oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)

     # 3. Async generator that streams embedding data and decrements the counter
-    async_gen = await oclient.embeddings.create(input=[input], model=model)
+    async_gen = await oclient.embeddings.create(input=[doc], model=model)

     await decrement_usage(endpoint, model)
@@ -1055,33 +1215,31 @@ async def openai_chat_completions_proxy(request: Request):
     max_completion_tokens = payload.get("max_completion_tokens")
     tools = payload.get("tools")

     if ":latest" in model:
         model = model.split(":latest")
         model = model[0]

     params = {
         "messages": messages,
         "model": model,
     }

+    optional_params = {
+        "tools": tools,
+        "response_format": response_format,
+        "stream_options": stream_options,
+        "max_completion_tokens": max_completion_tokens,
+        "max_tokens": max_tokens,
+        "temperature": temperature,
+        "top_p": top_p,
+        "seed": seed,
+        "presence_penalty": presence_penalty,
+        "frequency_penalty": frequency_penalty,
+        "stop": stop,
+        "stream": stream,
+    }

-    if tools is not None:
-        params["tools"] = tools
-    if response_format is not None:
-        params["response_format"] = response_format
-    if stream_options is not None:
-        params["stream_options"] = stream_options
-    if max_completion_tokens is not None:
-        params["max_completion_tokens"] = max_completion_tokens
-    if max_tokens is not None:
-        params["max_tokens"] = max_tokens
-    if temperature is not None:
-        params["temperature"] = temperature
-    if top_p is not None:
-        params["top_p"] = top_p
-    if seed is not None:
-        params["seed"] = seed
-    if presence_penalty is not None:
-        params["presence_penalty"] = presence_penalty
-    if frequency_penalty is not None:
-        params["frequency_penalty"] = frequency_penalty
+    params.update({k: v for k, v in optional_params.items() if v is not None})

     if not model:
         raise HTTPException(
@@ -1161,23 +1319,31 @@ async def openai_completions_proxy(request: Request):
     max_completion_tokens = payload.get("max_completion_tokens")
     suffix = payload.get("suffix")

     if ":latest" in model:
         model = model.split(":latest")
         model = model[0]

     params = {
         "prompt": prompt,
         "model": model,
-        "frequency_penalty": frequency_penalty,
-        "presence_penalty": presence_penalty,
-        "seed": seed,
     }

+    optional_params = {
+        "frequency_penalty": frequency_penalty,
+        "presence_penalty": presence_penalty,
+        "seed": seed,
+        "stop": stop,
+        "stream": stream,
+        "stream_options": stream_options,
+        "temperature": temperature,
+        "top_p": top_p,
+        "max_tokens": max_tokens,
+        "max_completion_tokens": max_completion_tokens,
+        "suffix": suffix
+    }

-    if stream_options is not None:
-        params["stream_options"] = stream_options

     params.update({k: v for k, v in optional_params.items() if v is not None})

     if not model:
         raise HTTPException(
             status_code=400, detail="Missing required field 'model'"
@@ -1238,8 +1404,8 @@ async def openai_models_proxy(request: Request):

     """
     # 1. Query all endpoints for models
-    tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
-    tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
+    tasks = [fetch.endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
+    tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
     all_models = await asyncio.gather(*tasks)

     models = {'data': []}
@@ -1289,9 +1455,7 @@ async def health_proxy(request: Request):
     * The HTTP status code is 200 when everything is healthy, 503 otherwise.
     """
     # Run all health checks in parallel
-    tasks = [
-        fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints
-    ]
+    tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints]

     results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -1333,6 +1497,8 @@ async def usage_stream(request: Request):
                 if await request.is_disconnected():
                     break
                 data = await queue.get()
+                if data is None:
+                    break
                 # Send the data as a single SSE message
                 yield f"data: {data}\n\n"
         finally:
@@ -1342,7 +1508,7 @@ async def usage_stream(request: Request):
     return StreamingResponse(event_generator(), media_type="text/event-stream")

 # -------------------------------------------------------------
-# 28. FastAPI startup event – load configuration
+# 28. FastAPI startup/shutdown events
 # -------------------------------------------------------------
 @app.on_event("startup")
 async def startup_event() -> None:
@@ -1350,4 +1516,18 @@ async def startup_event() -> None:
     # Load YAML config (or use defaults if not present)
     config = Config.from_yaml(Path("config.yaml"))
     print(f"Loaded configuration:\n endpoints={config.endpoints},\n "
           f"max_concurrent_connections={config.max_concurrent_connections}")
+
+    ssl_context = ssl.create_default_context()
+    connector = aiohttp.TCPConnector(limit=0, limit_per_host=512, ssl=ssl_context)
+    timeout = aiohttp.ClientTimeout(total=5, connect=5, sock_read=120, sock_connect=5)
+    session = aiohttp.ClientSession(connector=connector, timeout=timeout)
+
+    app_state["connector"] = connector
+    app_state["session"] = session
+
+
+@app.on_event("shutdown")
+async def shutdown_event() -> None:
+    await close_all_sse_queues()
+    await app_state["session"].close()
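The shared `aiohttp.ClientSession` created here is the core of this commit: the helpers now borrow `app_state["session"]` instead of constructing a throwaway client per call, so TCP connections are pooled for the life of the process and closed exactly once. A minimal sketch of the same lifecycle, using FastAPI's newer lifespan API purely as an illustration (the diff itself uses the older `@app.on_event` hooks):

```python
import aiohttp
from contextlib import asynccontextmanager
from fastapi import FastAPI

state = {"session": None}

@asynccontextmanager
async def lifespan(app: FastAPI):
    # One pooled session for the whole process...
    state["session"] = aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(limit=0, limit_per_host=512),
        timeout=aiohttp.ClientTimeout(total=5),
    )
    yield
    # ...closed exactly once on shutdown.
    await state["session"].close()

app = FastAPI(lifespan=lifespan)
```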
(stylesheet – sixth changed file; name not shown in the capture)

@@ -21,9 +21,9 @@
   top: 1rem;    /* distance from top edge */
   right: 1rem;  /* distance from right edge */
   cursor: pointer;
-  min-width: 2.5rem;
-  min-height: 2.5rem;
-  font-size: 1.5rem;
+  min-width: 1rem;
+  min-height: 1rem;
+  font-size: 1rem;
 }
 .tables-wrapper {
   display: flex;