Merge pull request #5 from nomyo-ai/dev-v0.3.x

Dev v0.3.x to main
Alpha Nerd 2025-09-19 13:09:23 +02:00 committed by GitHub
commit caaf26f0fc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 414 additions and 218 deletions

3
.gitignore vendored
View file

@@ -63,3 +63,6 @@ cython_debug/
# Logfile(s)
*.log
*.sqlite3
# Config
config.yaml

View file

@@ -18,9 +18,19 @@ endpoints:
- http://ollama0:11434
- http://ollama1:11434
- http://ollama2:11434
- https://api.openai.com/v1
# Maximum concurrent connections *per endpoint:model pair*
max_concurrent_connections: 2
# API keys for remote endpoints
# Set an environment variable like OPENAI_KEY
# Confirm the endpoints exactly match those in the endpoints block
api_keys:
"http://192.168.0.50:11434": "ollama"
"http://192.168.0.51:11434": "ollama"
"http://192.168.0.52:11434": "ollama"
"https://api.openai.com/v1": "${OPENAI_KEY}"
```
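The `${OPENAI_KEY}` placeholder is meant to be filled from the environment at load time. Below is a minimal sketch of how such expansion can work; the `load_api_keys` helper and its regex are illustrative assumptions, not the router's actual loader:
```
import os, re, yaml

def load_api_keys(path="config.yaml"):
    # Hypothetical helper: expand ${VAR} placeholders in api_keys from the environment.
    with open(path) as fh:
        cfg = yaml.safe_load(fh) or {}
    expanded = {}
    for endpoint, key in (cfg.get("api_keys") or {}).items():
        # Replace each ${NAME} token with the NAME environment variable (empty string if unset).
        expanded[endpoint] = re.sub(r"\$\{(\w+)\}", lambda m: os.environ.get(m.group(1), ""), key)
    return expanded
```
With `export OPENAI_KEY=...` set beforehand, the OpenAI entry resolves to the real key while the local Ollama entries stay literal strings.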
Run the NOMYO Router in a dedicated virtual environment: install the requirements, then start it with uvicorn:
@@ -30,6 +40,13 @@ python3 -m venv .venv/router
source .venv/router/bin/activate
pip3 install -r requirements.txt
```
On the shell, export your API key:
```
export OPENAI_KEY=YOUR_SECRET_API_KEY
```
Finally, you can start the router:
```

View file

@@ -3,8 +3,7 @@ endpoints:
- http://192.168.0.50:11434
- http://192.168.0.51:11434
- http://192.168.0.52:11434
#- https://openrouter.ai/api/v1
#- https://api.openai.com/v1
- https://api.openai.com/v1
# Maximum concurrent connections *per endpoint:model pair* (equals OLLAMA_NUM_PARALLEL)
max_concurrent_connections: 2
@@ -16,5 +15,4 @@ api_keys:
"http://192.168.0.50:11434": "ollama"
"http://192.168.0.51:11434": "ollama"
"http://192.168.0.52:11434": "ollama"
#"https://openrouter.ai/api/v1": "${OPENROUTER_KEY}"
#"https://api.openai.com/v1": "${OPENAI_KEY}"
"https://api.openai.com/v1": "${OPENAI_KEY}"

View file

@@ -14,8 +14,6 @@ fastapi-sse==1.1.1
frozenlist==1.7.0
h11==0.16.0
httpcore==1.0.9
httpx==0.28.1
httpx-aiohttp==0.1.8
idna==3.10
jiter==0.10.0
multidict==6.6.4

598
router.py
View file

@@ -2,17 +2,17 @@
title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing
author: alpha-nerd-nomyo
author_url: https://github.com/nomyo-ai
version: 0.2.2
version: 0.3
license: AGPL
"""
# -------------------------------------------------------------
import json, time, asyncio, yaml, httpx, ollama, openai, os, re
from httpx_aiohttp import AiohttpTransport
import json, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random
from pathlib import Path
from typing import Dict, Set, List, Optional
from fastapi import FastAPI, Request, HTTPException
from fastapi_sse import sse_handler
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLResponse, RedirectResponse
from pydantic import Field
from pydantic_settings import BaseSettings
@@ -33,6 +33,14 @@ _error_cache: dict[str, float] = {}
_subscribers: Set[asyncio.Queue] = set()
_subscribers_lock = asyncio.Lock()
# ------------------------------------------------------------------
# aiohttp Global Sessions
# ------------------------------------------------------------------
app_state = {
"session": None,
"connector": None,
}
# -------------------------------------------------------------
# 1. Configuration loader
# -------------------------------------------------------------
@@ -85,7 +93,13 @@ config = Config()
# -------------------------------------------------------------
app = FastAPI()
sse_handler.app = app
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "DELETE"],
allow_headers=["Authorization", "Content-Type"],
)
# -------------------------------------------------------------
# 3. Global state: per-endpoint per-model active connection counters
# -------------------------------------------------------------
@@ -95,128 +109,115 @@ usage_lock = asyncio.Lock() # protects access to usage_counts
# -------------------------------------------------------------
# 4. Helper functions
# -------------------------------------------------------------
aiotimeout = aiohttp.ClientTimeout(total=5)
def _is_fresh(cached_at: float, ttl: int) -> bool:
return (time.time() - cached_at) < ttl
def get_httpx_client(endpoint: str) -> httpx.AsyncClient:
"""
Use persistent connections to request endpoint info for reliable results
in high load situations or saturated endpoints.
"""
return httpx.AsyncClient(
base_url=endpoint,
timeout=httpx.Timeout(5.0, read=5.0, write=None, connect=5.0),
#limits=httpx.Limits(
# max_keepalive_connections=64,
# max_connections=64
#),
transport=AiohttpTransport()
)
async def _ensure_success(resp: aiohttp.ClientResponse) -> None:
if resp.status >= 400:
text = await resp.text()
raise HTTPException(status_code=resp.status, detail=text)
async def fetch_available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
"""
Query <endpoint>/api/tags and return a set of all model names that the
endpoint *advertises* (i.e. is capable of serving). This endpoint lists
every model that is installed on the Ollama instance, regardless of
whether the model is currently loaded into memory.
class fetch:
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
"""
Query <endpoint>/api/tags and return a set of all model names that the
endpoint *advertises* (i.e. is capable of serving). This endpoint lists
every model that is installed on the Ollama instance, regardless of
whether the model is currently loaded into memory.
If the request fails (e.g. timeout, 5xx, or malformed response), an empty
set is returned.
"""
headers = None
if api_key is not None:
headers = {"Authorization": "Bearer " + api_key}
If the request fails (e.g. timeout, 5xx, or malformed response), an empty
set is returned.
"""
headers = None
if api_key is not None:
headers = {"Authorization": "Bearer " + api_key}
if endpoint in _models_cache:
models, cached_at = _models_cache[endpoint]
if _is_fresh(cached_at, 300):
return models
if endpoint in _models_cache:
models, cached_at = _models_cache[endpoint]
if _is_fresh(cached_at, 300):
return models
else:
# stale entry -> drop it
del _models_cache[endpoint]
if endpoint in _error_cache:
if _is_fresh(_error_cache[endpoint], 1):
# Still within the short error TTL -> pretend nothing is available
return set()
else:
# Error expired -> remove it
del _error_cache[endpoint]
if "/v1" in endpoint:
endpoint_url = f"{endpoint}/models"
key = "data"
else:
# stale entry -> drop it
del _models_cache[endpoint]
endpoint_url = f"{endpoint}/api/tags"
key = "models"
client: aiohttp.ClientSession = app_state["session"]
try:
async with client.get(endpoint_url, headers=headers) as resp:
await _ensure_success(resp)
data = await resp.json()
if endpoint in _error_cache:
if _is_fresh(_error_cache[endpoint], 1):
# Still within the short error TTL -> pretend nothing is available
items = data.get(key, [])
models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
if models:
_models_cache[endpoint] = (models, time.time())
return models
else:
# Empty list -> treat as “no models”, but still cache for 300s
_models_cache[endpoint] = (models, time.time())
return models
except Exception as e:
# Treat any error as if the endpoint offers no models
print(f"[fetch.available_models] {endpoint} error: {e}")
_error_cache[endpoint] = time.time()
return set()
else:
# Error expired -> remove it
del _error_cache[endpoint]
try:
client = get_httpx_client(endpoint)
if "/v1" in endpoint:
resp = await client.get(f"/models", headers=headers)
else:
resp = await client.get(f"/api/tags")
resp.raise_for_status()
data = resp.json()
# Expected format:
# {"models": [{"name": "model1"}, {"name": "model2"}]}
if "/v1" in endpoint:
models = {m.get("id") for m in data.get("data", []) if m.get("id")}
else:
async def loaded_models(endpoint: str) -> Set[str]:
"""
Query <endpoint>/api/ps and return a set of model names that are currently
loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
set is returned.
"""
client: aiohttp.ClientSession = app_state["session"]
try:
async with client.get(f"{endpoint}/api/ps") as resp:
await _ensure_success(resp)
data = await resp.json()
# The response format is:
# {"models": [{"name": "model1"}, {"name": "model2"}]}
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
return models
except Exception:
# If anything goes wrong we simply assume the endpoint has no models
return set()
async def endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
"""
Query <endpoint>/<route> to fetch <detail> and return a List of dicts with details
for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail.
"""
client: aiohttp.ClientSession = app_state["session"]
headers = None
if api_key is not None:
headers = {"Authorization": "Bearer " + api_key}
if models:
_models_cache[endpoint] = (models, time.time())
return models
else:
# Empty list -> treat as “no models”, but still cache for 300s
_models_cache[endpoint] = (models, time.time())
return models
except Exception as e:
# Treat any error as if the endpoint offers no models
print(f"[fetch_available_models] {endpoint} error: {e}")
_error_cache[endpoint] = time.time()
return set()
finally:
await client.aclose()
async def fetch_loaded_models(endpoint: str) -> Set[str]:
"""
Query <endpoint>/api/ps and return a set of model names that are currently
loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
set is returned.
"""
try:
client = get_httpx_client(endpoint)
resp = await client.get(f"/api/ps")
resp.raise_for_status()
data = resp.json()
# The response format is:
# {"models": [{"name": "model1"}, {"name": "model2"}]}
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
return models
except Exception:
# If anything goes wrong we simply assume the endpoint has no models
return set()
finally:
await client.aclose()
async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
"""
Query <endpoint>/<route> to fetch <detail> and return a List of dicts with details
for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail.
"""
headers = None
if api_key is not None:
headers = {"Authorization": "Bearer " + api_key}
try:
client = get_httpx_client(endpoint)
resp = await client.get(f"{route}", headers=headers)
resp.raise_for_status()
data = resp.json()
detail = data.get(detail, [])
return detail
except Exception as e:
# If anything goes wrong we cannot return the details
print(e)
return []
finally:
await client.aclose()
try:
async with client.get(f"{endpoint}{route}", headers=headers) as resp:
await _ensure_success(resp)
data = await resp.json()
detail = data.get(detail, [])
return detail
except Exception as e:
# If anything goes wrong we cannot return the details
print(e)
return []
def ep2base(ep):
if "/v1" in ep:
@@ -257,6 +258,71 @@ async def decrement_usage(endpoint: str, model: str) -> None:
# usage_counts.pop(endpoint, None)
await publish_snapshot()
def iso8601_ns():
ns_since_epoch = time.time_ns()
dt = datetime.datetime.fromtimestamp(
ns_since_epoch / 1_000_000_000, # seconds
tz=datetime.timezone.utc
)
iso8601_with_ns = (
dt.strftime("%Y-%m-%dT%H:%M:%S.") + f"{ns_since_epoch % 1_000_000_000:09d}Z"
)
return iso8601_with_ns
class rechunk:
def openai_chat_completion2ollama(chunk: dict, stream: bool, start_ts: float):
rechunk = { "model": chunk.model,
"created_at": iso8601_ns() ,
"done_reason": chunk.choices[0].finish_reason,
"load_duration": None,
"prompt_eval_count": None,
"prompt_eval_duration": None,
"eval_count": None,
"eval_duration": None,
"eval_count": (chunk.usage.completion_tokens if chunk.usage is not None else None),
"prompt_eval_count": (chunk.usage.prompt_tokens if chunk.usage is not None else None),
"eval_duration": (int((time.perf_counter() - start_ts) * 1000) if chunk.usage is not None else None),
"response_token/s": (round(chunk.usage.total_tokens / (time.perf_counter() - start_ts), 2) if chunk.usage is not None else None)
}
if stream == True:
rechunk["message"] = {"role": chunk.choices[0].delta.role or "assistant", "content": chunk.choices[0].delta.content, "thinking": None, "images": None, "tool_name": None, "tool_calls": None}
else:
rechunk["message"] = {"role": chunk.choices[0].message.role or "assistant", "content": chunk.choices[0].message.content, "thinking": None, "images": None, "tool_name": None, "tool_calls": None}
return rechunk
def openai_completion2ollama(chunk: dict, stream: bool, start_ts: float):
with_thinking = chunk.choices[0] if chunk.choices[0] else None
thinking = getattr(with_thinking, "reasoning", None) if with_thinking else None
rechunk = { "model": chunk.model,
"created_at": iso8601_ns(),
"load_duration": None,
"done_reason": chunk.choices[0].finish_reason,
"total_duration": None,
"eval_duration": (int((time.perf_counter() - start_ts) * 1000) if chunk.usage is not None else None),
"thinking": thinking,
"context": None,
"response": chunk.choices[0].text
}
return rechunk
def openai_embeddings2ollama(chunk: dict):
rechunk = {"embedding": chunk.data[0].embedding}
return rechunk
def openai_embed2ollama(chunk: dict, model: str):
rechunk = { "model": model,
"created_at": iso8601_ns(),
"done": None,
"done_reason": None,
"total_duration": None,
"load_duration": None,
"prompt_eval_count": None,
"prompt_eval_duration": None,
"eval_count": None,
"eval_duration": None,
"embeddings": [chunk.data[0].embedding]
}
return rechunk
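The `rechunk` helpers adapt OpenAI response objects to Ollama's reply shape before the proxy handlers serialize them to NDJSON lines. A minimal usage sketch, assuming router.py is importable as a module; the `SimpleNamespace` fixture below is a stand-in for an OpenAI streaming chunk, not real API output:
```
import json, time
from types import SimpleNamespace
from router import rechunk  # assumption: router.py is importable as a module

# Stand-in for one OpenAI streaming chunk (usage only arrives with the final chunk).
fake_chunk = SimpleNamespace(
    model="gpt-4o-mini",
    usage=None,
    choices=[SimpleNamespace(
        finish_reason=None,
        delta=SimpleNamespace(role="assistant", content="Hello"),
    )],
)

line = rechunk.openai_chat_completion2ollama(fake_chunk, stream=True, start_ts=time.perf_counter())
print(json.dumps(line))  # Ollama-style line: model, created_at, message, token counters (None here)
```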
# ------------------------------------------------------------------
# SSE Helpers
# ------------------------------------------------------------------
@@ -269,6 +335,11 @@ async def publish_snapshot():
continue
await q.put(snapshot)
async def close_all_sse_queues():
for q in list(_subscribers):
# sentinel value that the generator will recognise
await q.put(None)
# ------------------------------------------------------------------
# Subscriber helpers
# ------------------------------------------------------------------
@@ -305,16 +376,17 @@ async def choose_endpoint(model: str) -> str:
1 Query every endpoint for its advertised models (`/api/tags`).
2 Build a list of endpoints that contain the requested model.
3 For those endpoints, find those that have the model loaded
(`/api/ps`) *and* still have a free slot.
(`/api/ps`) *and* still have a free slot.
4 If none are both loaded and free, fall back to any endpoint
from the filtered list that simply has a free slot.
from the filtered list that simply has a free slot and randomly
select one.
5 If all are saturated, pick any endpoint from the filtered list
(the request will queue on that endpoint).
(the request will queue on that endpoint).
6 If no endpoint advertises the model at all, raise an error.
"""
# 1⃣ Gather advertised-model sets for all endpoints concurrently
tag_tasks = [fetch_available_models(ep) for ep in config.endpoints if "/v1" not in ep]
tag_tasks += [fetch_available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
tag_tasks = [fetch.available_models(ep) for ep in config.endpoints if "/v1" not in ep]
tag_tasks += [fetch.available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
advertised_sets = await asyncio.gather(*tag_tasks)
# 2⃣ Filter endpoints that advertise the requested model
@@ -325,16 +397,24 @@ async def choose_endpoint(model: str) -> str:
# 6
if not candidate_endpoints:
raise RuntimeError(
f"None of the configured endpoints ({', '.join(config.endpoints)}) "
f"advertise the model '{model}'."
)
if ":latest" in model: #ollama naming convention not applicable to openai
model = model.split(":latest")
model = model[0]
candidate_endpoints = [
ep for ep, models in zip(config.endpoints, advertised_sets)
if model in models
]
if not candidate_endpoints:
raise RuntimeError(
f"None of the configured endpoints ({', '.join(config.endpoints)}) "
f"advertise the model '{model}'."
)
# 3⃣ Among the candidates, find those that have the model *loaded*
# (concurrently, but only for the filtered list)
load_tasks = [fetch_loaded_models(ep) for ep in candidate_endpoints]
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
loaded_sets = await asyncio.gather(*load_tasks)
async with usage_lock:
# Helper: get current usage count for (endpoint, model)
def current_usage(ep: str) -> int:
@@ -343,7 +423,7 @@ async def choose_endpoint(model: str) -> str:
# 3⃣ Endpoints that have the model loaded *and* a free slot
loaded_and_free = [
ep for ep, models in zip(candidate_endpoints, loaded_sets)
if model in models and usage_counts[ep].get(model, 0) < config.max_concurrent_connections
if model in models and usage_counts.get(ep, {}).get(model, 0) < config.max_concurrent_connections
]
if loaded_and_free:
@@ -353,12 +433,11 @@ async def choose_endpoint(model: str) -> str:
# 4⃣ Endpoints among the candidates that simply have a free slot
endpoints_with_free_slot = [
ep for ep in candidate_endpoints
if usage_counts[ep].get(model, 0) < config.max_concurrent_connections
if usage_counts.get(ep, {}).get(model, 0) < config.max_concurrent_connections
]
if endpoints_with_free_slot:
ep = min(endpoints_with_free_slot, key=current_usage)
return ep
return random.choice(endpoints_with_free_slot)
# 5⃣ All candidate endpoints are saturated -> pick one with the lowest usage count (will queue)
ep = min(candidate_endpoints, key=current_usage)
@@ -372,7 +451,6 @@ async def proxy(request: Request):
"""
Proxy a generate request to Ollama and stream the response back to the client.
"""
# 1. Parse and validate request
try:
body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8"))
@@ -386,7 +464,7 @@ async def proxy(request: Request):
stream = payload.get("stream")
think = payload.get("think")
raw = payload.get("raw")
format = payload.get("format")
_format = payload.get("format")
images = payload.get("images")
options = payload.get("options")
keep_alive = payload.get("keep_alive")
@@ -402,29 +480,53 @@ async def proxy(request: Request):
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Decide which endpoint to use
endpoint = await choose_endpoint(model)
is_openai_endpoint = "/v1" in endpoint
if is_openai_endpoint:
if ":latest" in model:
model = model.split(":latest")
model = model[0]
params = {
"prompt": prompt,
"model": model,
}
# Increment usage counter for this endpoint:model pair
optional_params = {
"stream": stream,
}
params.update({k: v for k, v in optional_params.items() if v is not None})
oclient = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, model)
# 3. Create Ollama client instance
client = ollama.AsyncClient(host=endpoint)
# 4. Async generator that streams data and decrements the counter
async def stream_generate_response():
try:
async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=format, images=images, options=options, keep_alive=keep_alive)
if is_openai_endpoint:
start_ts = time.perf_counter()
async_gen = await oclient.completions.create(**params)
else:
async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive)
if stream == True:
async for chunk in async_gen:
if is_openai_endpoint:
chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
if hasattr(chunk, "model_dump_json"):
json_line = chunk.model_dump_json()
else:
json_line = json.dumps(chunk)
yield json_line.encode("utf-8") + b"\n"
else:
if is_openai_endpoint:
response = rechunk.openai_completion2ollama(async_gen, stream, start_ts)
response = json.dumps(response)
else:
response = async_gen.model_dump_json()
json_line = (
async_gen.model_dump_json()
response
if hasattr(async_gen, "model_dump_json")
else json.dumps(async_gen)
)
@@ -468,23 +570,46 @@ async def chat_proxy(request: Request):
)
if not isinstance(messages, list):
raise HTTPException(
status_code=400, detail="Missing or invalid 'message' field (must be a list)"
status_code=400, detail="Missing or invalid 'messages' field (must be a list)"
)
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
await increment_usage(endpoint, model)
client = ollama.AsyncClient(host=endpoint)
is_openai_endpoint = "/v1" in endpoint
if is_openai_endpoint:
if ":latest" in model:
model = model.split(":latest")
model = model[0]
params = {
"messages": messages,
"model": model,
}
optional_params = {
"tools": tools,
"stream": stream,
}
params.update({k: v for k, v in optional_params.items() if v is not None})
oclient = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, model)
# 3. Async generator that streams chat data and decrements the counter
async def stream_chat_response():
try:
# The chat method returns a generator of dicts (or GenerateResponse)
async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
if is_openai_endpoint:
start_ts = time.perf_counter()
async_gen = await oclient.chat.completions.create(**params)
else:
async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
if stream == True:
async for chunk in async_gen:
if is_openai_endpoint:
chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts)
# `chunk` can be a dict or a pydantic model -> dump to JSON safely
if hasattr(chunk, "model_dump_json"):
json_line = chunk.model_dump_json()
@@ -492,8 +617,13 @@ async def chat_proxy(request: Request):
json_line = json.dumps(chunk)
yield json_line.encode("utf-8") + b"\n"
else:
if is_openai_endpoint:
response = rechunk.openai_chat_completion2ollama(async_gen, stream, start_ts)
response = json.dumps(response)
else:
response = async_gen.model_dump_json()
json_line = (
async_gen.model_dump_json()
response
if hasattr(async_gen, "model_dump_json")
else json.dumps(async_gen)
)
@@ -541,14 +671,24 @@ async def embedding_proxy(request: Request):
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
is_openai_endpoint = "/v1" in endpoint
if is_openai_endpoint:
if ":latest" in model:
model = model.split(":latest")
model = model[0]
client = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, model)
client = ollama.AsyncClient(host=endpoint)
# 3. Async generator that streams embedding data and decrements the counter
async def stream_embedding_response():
try:
# The chat method returns a generator of dicts (or GenerateResponse)
async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
if is_openai_endpoint:
async_gen = await client.embeddings.create(input=[prompt], model=model)
async_gen = rechunk.openai_embeddings2ollama(async_gen)
else:
async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
if hasattr(async_gen, "model_dump_json"):
json_line = async_gen.model_dump_json()
else:
@@ -579,7 +719,7 @@ async def embed_proxy(request: Request):
payload = json.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
input = payload.get("input")
_input = payload.get("input")
truncate = payload.get("truncate")
options = payload.get("options")
keep_alive = payload.get("keep_alive")
@@ -588,7 +728,7 @@ async def embed_proxy(request: Request):
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
if not input:
if not _input:
raise HTTPException(
status_code=400, detail="Missing required field 'input'"
)
@@ -597,14 +737,24 @@ async def embed_proxy(request: Request):
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
is_openai_endpoint = "/v1" in endpoint
if is_openai_endpoint:
if ":latest" in model:
model = model.split(":latest")
model = model[0]
client = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, model)
client = ollama.AsyncClient(host=endpoint)
# 3. Async generator that streams embed data and decrements the counter
async def stream_embedding_response():
try:
# The chat method returns a generator of dicts (or GenerateResponse)
async_gen = await client.embed(model=model, input=input, truncate=truncate, options=options, keep_alive=keep_alive)
if is_openai_endpoint:
async_gen = await client.embeddings.create(input=[_input], model=model)
async_gen = rechunk.openai_embed2ollama(async_gen, model)
else:
async_gen = await client.embed(model=model, input=_input, truncate=truncate, options=options, keep_alive=keep_alive)
if hasattr(async_gen, "model_dump_json"):
json_line = async_gen.model_dump_json()
else:
@@ -877,7 +1027,7 @@ async def version_proxy(request: Request):
"""
# 1. Query all endpoints for version
tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
all_versions = await asyncio.gather(*tasks)
def version_key(v):
@@ -900,12 +1050,21 @@ async def tags_proxy(request: Request):
"""
# 1. Query all endpoints for models
tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
tasks = [fetch.endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
all_models = await asyncio.gather(*tasks)
models = {'models': []}
for modellist in all_models:
for model in modellist:
if not "model" in model.keys(): # Relable OpenAI models with Ollama Model.model from Model.id
model['model'] = model['id'] + ":latest"
else:
model['id'] = model['model']
if not "name" in model.keys(): # Relable OpenAI models with Ollama Model.name from Model.model to have model,name keys
model['name'] = model['model']
else:
model['id'] = model['model']
models['models'] += modellist
# 2. Return a JSONResponse with a deduplicated list of unique models for inference
@@ -924,7 +1083,7 @@ async def ps_proxy(request: Request):
"""
# 1. Query all endpoints for running models
tasks = [fetch_endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints if "/v1" not in ep]
tasks = [fetch.endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints if "/v1" not in ep]
loaded_models = await asyncio.gather(*tasks)
models = {'models': []}
@@ -960,22 +1119,22 @@ async def config_proxy(request: Request):
"""
async def check_endpoint(url: str):
try:
async with httpx.AsyncClient(timeout=1, transport=AiohttpTransport()) as client:
if "/v1" in url:
headers = {"Authorization": "Bearer " + config.api_keys[url]}
r = await client.get(f"{url}/models", headers=headers)
else:
r = await client.get(f"{url}/api/version")
r.raise_for_status()
data = r.json()
if "/v1" in url:
return {"url": url, "status": "ok", "version": "latest"}
else:
return {"url": url, "status": "ok", "version": data.get("version")}
except Exception as exc:
return {"url": url, "status": "error", "detail": str(exc)}
finally:
await client.aclose()
client: aiohttp.ClientSession = app_state["session"]
if "/v1" in url:
headers = {"Authorization": "Bearer " + config.api_keys[url]}
async with client.get(f"{url}/models", headers=headers) as resp:
await _ensure_success(resp)
data = await resp.json()
else:
async with client.get(f"{url}/api/version") as resp:
await _ensure_success(resp)
data = await resp.json()
if "/v1" in url:
return {"url": url, "status": "ok", "version": "latest"}
else:
return {"url": url, "status": "ok", "version": data.get("version")}
except Exception as e:
return {"url": url, "status": "error", "detail": str(e)}
results = await asyncio.gather(*[check_endpoint(ep) for ep in config.endpoints])
return {"endpoints": results}
@@ -995,14 +1154,14 @@ async def openai_embedding_proxy(request: Request):
payload = json.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
input = payload.get("input")
doc = payload.get("input")
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
if not input:
if not doc:
raise HTTPException(
status_code=400, detail="Missing required field 'input'"
)
@@ -1016,10 +1175,11 @@ async def openai_embedding_proxy(request: Request):
api_key = config.api_keys[endpoint]
else:
api_key = "ollama"
oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key=api_key)
base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
# 3. Async generator that streams embedding data and decrements the counter
async_gen = await oclient.embeddings.create(input=[input], model=model)
async_gen = await oclient.embeddings.create(input=[doc], model=model)
await decrement_usage(endpoint, model)
@@ -1055,33 +1215,31 @@ async def openai_chat_completions_proxy(request: Request):
max_completion_tokens = payload.get("max_completion_tokens")
tools = payload.get("tools")
if ":latest" in model:
model = model.split(":latest")
model = model[0]
params = {
"messages": messages,
"model": model,
}
optional_params = {
"tools": tools,
"response_format": response_format,
"stream_options": stream_options,
"max_completion_tokens": max_completion_tokens,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"seed": seed,
"presence_penalty": presence_penalty,
"frequency_penalty": frequency_penalty,
"stop": stop,
"stream": stream,
}
if tools is not None:
params["tools"] = tools
if response_format is not None:
params["response_format"] = response_format
if stream_options is not None:
params["stream_options"] = stream_options
if max_completion_tokens is not None:
params["max_completion_tokens"] = max_completion_tokens
if max_tokens is not None:
params["max_tokens"] = max_tokens
if temperature is not None:
params["temperature"] = temperature
if top_p is not None:
params["top_p"] = top_p
if seed is not None:
params["seed"] = seed
if presence_penalty is not None:
params["presence_penalty"] = presence_penalty
if frequency_penalty is not None:
params["frequency_penalty"] = frequency_penalty
params.update({k: v for k, v in optional_params.items() if v is not None})
if not model:
raise HTTPException(
@@ -1161,23 +1319,31 @@ async def openai_completions_proxy(request: Request):
max_completion_tokens = payload.get("max_completion_tokens")
suffix = payload.get("suffix")
if ":latest" in model:
model = model.split(":latest")
model = model[0]
params = {
"prompt": prompt,
"model": model,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"seed": seed,
}
optional_params = {
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"seed": seed,
"stop": stop,
"stream": stream,
"stream_options": stream_options,
"temperature": temperature,
"top_p": top_p,
"top_p": top_p,
"max_tokens": max_tokens,
"max_completion_tokens": max_completion_tokens,
"suffix": suffix
}
if stream_options is not None:
params["stream_options"] = stream_options
params.update({k: v for k, v in optional_params.items() if v is not None})
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
@@ -1238,8 +1404,8 @@ async def openai_models_proxy(request: Request):
"""
# 1. Query all endpoints for models
tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
tasks = [fetch.endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
all_models = await asyncio.gather(*tasks)
models = {'data': []}
@@ -1289,9 +1455,7 @@ async def health_proxy(request: Request):
* The HTTP status code is 200 when everything is healthy, 503 otherwise.
"""
# Run all health checks in parallel
tasks = [
fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints
]
tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints]
results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -1333,6 +1497,8 @@ async def usage_stream(request: Request):
if await request.is_disconnected():
break
data = await queue.get()
if data is None:
break
# Send the data as a single SSE message
yield f"data: {data}\n\n"
finally:
@@ -1342,7 +1508,7 @@ async def usage_stream(request: Request):
return StreamingResponse(event_generator(), media_type="text/event-stream")
# -------------------------------------------------------------
# 28. FastAPI startup event - load configuration
# 28. FastAPI startup/shutdown events
# -------------------------------------------------------------
@app.on_event("startup")
async def startup_event() -> None:
@@ -1350,4 +1516,18 @@ async def startup_event() -> None:
# Load YAML config (or use defaults if not present)
config = Config.from_yaml(Path("config.yaml"))
print(f"Loaded configuration:\n endpoints={config.endpoints},\n "
f"max_concurrent_connections={config.max_concurrent_connections}")
f"max_concurrent_connections={config.max_concurrent_connections}")
ssl_context = ssl.create_default_context()
connector = aiohttp.TCPConnector(limit=0, limit_per_host=512, ssl=ssl_context)
timeout = aiohttp.ClientTimeout(total=5, connect=5, sock_read=120, sock_connect=5)
session = aiohttp.ClientSession(connector=connector, timeout=timeout)
app_state["connector"] = connector
app_state["session"] = session
@app.on_event("shutdown")
async def shutdown_event() -> None:
await close_all_sse_queues()
await app_state["session"].close()

View file

@@ -21,9 +21,9 @@
top: 1rem; /* distance from top edge */
right: 1rem; /* distance from right edge */
cursor: pointer;
min-width: 2.5rem;
min-height: 2.5rem;
font-size: 1.5rem;
min-width: 1rem;
min-height: 1rem;
font-size: 1rem;
}
.tables-wrapper {
display: flex;