Initial commit
This commit is contained in:
parent
5b19fbf739
commit
5f1f3f7b57
3 changed files with 782 additions and 0 deletions
8
config.yaml
Normal file
8
config.yaml
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
# config.yaml
|
||||
endpoints:
|
||||
- http://192.168.0.50:11434
|
||||
- http://192.168.0.51:11434
|
||||
- http://192.168.0.52:11434
|
||||
|
||||
# Maximum concurrent connections *per endpoint‑model pair*
|
||||
max_concurrent_connections: 2
|
||||
21
requirements.txt
Normal file
21
requirements.txt
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
annotated-types==0.7.0
|
||||
anyio==4.10.0
|
||||
certifi==2025.8.3
|
||||
click==8.2.1
|
||||
exceptiongroup==1.3.0
|
||||
fastapi==0.116.1
|
||||
h11==0.16.0
|
||||
httpcore==1.0.9
|
||||
httpx==0.28.1
|
||||
idna==3.10
|
||||
ollama==0.5.3
|
||||
pydantic==2.11.7
|
||||
pydantic-settings==2.10.1
|
||||
pydantic_core==2.33.2
|
||||
python-dotenv==1.1.1
|
||||
PyYAML==6.0.2
|
||||
sniffio==1.3.1
|
||||
starlette==0.47.2
|
||||
typing-inspection==0.4.1
|
||||
typing_extensions==4.14.1
|
||||
uvicorn==0.35.0
|
||||
753
router.py
Normal file
753
router.py
Normal file
|
|
@ -0,0 +1,753 @@
|
|||
"""
|
||||
title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing
|
||||
author: alpha-nerd-nomyo
|
||||
author_url: https://github.com/nomyo-ai
|
||||
version: 0.1
|
||||
license: AGPL
|
||||
"""
|
||||
# -------------------------------------------------------------
|
||||
import json, random, asyncio, yaml, httpx, ollama
|
||||
from pathlib import Path
|
||||
from typing import Dict, Set, List
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from starlette.responses import StreamingResponse, JSONResponse, Response
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from collections import defaultdict
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 1. Configuration loader
|
||||
# -------------------------------------------------------------
|
||||
class Config(BaseSettings):
|
||||
# List of Ollama endpoints
|
||||
endpoints: list[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"http://localhost:11434",
|
||||
]
|
||||
)
|
||||
# Max concurrent connections per endpoint‑model pair, see OLLAMA_NUM_PARALLEL
|
||||
max_concurrent_connections: int = 1
|
||||
|
||||
class Config:
|
||||
# Load from `config.yaml` first, then from env variables
|
||||
env_prefix = "OLLAMA_PROXY_"
|
||||
yaml_file = Path("config.yaml") # relative to cwd
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: Path) -> "Config":
|
||||
"""Load the YAML file and create the Config instance."""
|
||||
if path.exists():
|
||||
with path.open("r", encoding="utf-8") as fp:
|
||||
data = yaml.safe_load(fp) or {}
|
||||
return cls(**data)
|
||||
return cls()
|
||||
|
||||
# Create the global config object – it will be overwritten on startup
|
||||
config = Config()
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 2. FastAPI application
|
||||
# -------------------------------------------------------------
|
||||
app = FastAPI()
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 3. Global state: per‑endpoint per‑model active connection counters
|
||||
# -------------------------------------------------------------
|
||||
usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||
usage_lock = asyncio.Lock() # protects access to usage_counts
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 4. Helperfunctions
|
||||
# -------------------------------------------------------------
|
||||
async def fetch_loaded_models(endpoint: str) -> Set[str]:
|
||||
"""
|
||||
Query <endpoint>/api/ps and return a set of model names that are currently
|
||||
loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
|
||||
set is returned.
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=1.0) as client:
|
||||
resp = await client.get(f"{endpoint}/api/ps")
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
# The response format is:
|
||||
# {"models": [{"name": "model1"}, {"name": "model2"}]}
|
||||
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
|
||||
return models
|
||||
except Exception:
|
||||
# If anything goes wrong we simply assume the endpoint has no models
|
||||
return set()
|
||||
|
||||
async def fetch_endpoint_details(endpoint: str, route: str, detail: str) -> List[dict]:
|
||||
"""
|
||||
Query <endpoint>/<route> to fetch <detail> and return a List of dicts with details
|
||||
for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail.
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=1.0) as client:
|
||||
resp = await client.get(f"{endpoint}{route}")
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
detail = data.get(detail)
|
||||
return detail
|
||||
except Exception:
|
||||
# If anything goes wrong we cannot reply versions
|
||||
return {detail: "N/A"}
|
||||
|
||||
def dedupe_on_keys(dicts, key_fields):
|
||||
"""
|
||||
Helper function to deduplicate endpoint details based on given dict keys.
|
||||
"""
|
||||
seen = set()
|
||||
out = []
|
||||
for d in dicts:
|
||||
# Build a tuple of the values for the chosen keys
|
||||
key = tuple(d.get(k) for k in key_fields)
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
out.append(d)
|
||||
return out
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 5. Endpoint selection logic (respecting the configurable limit)
|
||||
# -------------------------------------------------------------
|
||||
async def choose_endpoint(model: str) -> str:
|
||||
"""
|
||||
Determine which endpoint to use for the given model while respecting the
|
||||
`max_concurrent_connections` per endpoint‑model pair.
|
||||
|
||||
The selection algorithm is as follows:
|
||||
|
||||
1. Find all endpoints that have the model *already loaded* and that still
|
||||
have a free slot (< max_concurrent_connections).
|
||||
2. If none exist, find any endpoint that has a free slot regardless of
|
||||
whether the model is loaded – the endpoint will load the model on demand.
|
||||
3. If all endpoints are at capacity for this model, pick any endpoint
|
||||
arbitrarily – the request will queue on that endpoint.
|
||||
"""
|
||||
# Gather loaded models for all endpoints concurrently
|
||||
tasks = [fetch_loaded_models(ep) for ep in config.endpoints]
|
||||
loaded_sets = await asyncio.gather(*tasks)
|
||||
|
||||
async with usage_lock:
|
||||
# 1️⃣ Endpoints that have the model loaded *and* a free slot
|
||||
loaded_and_free = [
|
||||
ep for ep, models in zip(config.endpoints, loaded_sets)
|
||||
if model in models and usage_counts[ep].get(model, 0) < config.max_concurrent_connections
|
||||
]
|
||||
|
||||
if loaded_and_free:
|
||||
# Prefer an endpoint that already hosts the model and has capacity
|
||||
return random.choice(loaded_and_free)
|
||||
|
||||
# 2️⃣ Endpoints that simply have a free slot (model may or may not be loaded)
|
||||
endpoints_with_free_slot = [
|
||||
ep for ep in config.endpoints
|
||||
if usage_counts[ep].get(model, 0) < config.max_concurrent_connections
|
||||
]
|
||||
|
||||
if endpoints_with_free_slot:
|
||||
return random.choice(endpoints_with_free_slot)
|
||||
|
||||
# 3️⃣ All endpoints are at capacity – pick any (will queue on that endpoint according to OLLAMA_MAX_QUEUE)
|
||||
return random.choice(config.endpoints)
|
||||
|
||||
async def increment_usage(endpoint: str, model: str) -> None:
|
||||
async with usage_lock:
|
||||
usage_counts[endpoint][model] += 1
|
||||
|
||||
async def decrement_usage(endpoint: str, model: str) -> None:
|
||||
async with usage_lock:
|
||||
# Avoid negative counts
|
||||
current = usage_counts[endpoint].get(model, 0)
|
||||
if current > 0:
|
||||
usage_counts[endpoint][model] = current - 1
|
||||
# Optionally, clean up zero entries
|
||||
if usage_counts[endpoint].get(model, 0) == 0:
|
||||
usage_counts[endpoint].pop(model, None)
|
||||
if not usage_counts[endpoint]:
|
||||
usage_counts.pop(endpoint, None)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 6. API route – Generate
|
||||
# -------------------------------------------------------------
|
||||
@app.post("/api/generate")
|
||||
async def proxy(request: Request):
|
||||
"""
|
||||
Proxy a generate request to Ollama and stream the response back to the client.
|
||||
"""
|
||||
# 1. Parse and validate request
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
payload = json.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
model = payload.get("model")
|
||||
prompt = payload.get("prompt")
|
||||
suffix = payload.get("suffix")
|
||||
system = payload.get("system")
|
||||
template = payload.get("template")
|
||||
context = payload.get("context")
|
||||
stream = payload.get("stream")
|
||||
think = payload.get("think")
|
||||
raw = payload.get("raw")
|
||||
format = payload.get("format")
|
||||
images = payload.get("images")
|
||||
options = payload.get("options")
|
||||
keep_alive = payload.get("keep_alive")
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
if not prompt:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'prompt'"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Decide which endpoint to use
|
||||
endpoint = await choose_endpoint(model)
|
||||
|
||||
# Increment usage counter for this endpoint‑model pair
|
||||
await increment_usage(endpoint, model)
|
||||
|
||||
# 3. Create Ollama client instance
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
|
||||
# 4. Async generator that streams data and decrements the counter
|
||||
async def stream_generate_response():
|
||||
try:
|
||||
async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=format, images=images, options=options, keep_alive=keep_alive)
|
||||
if stream == True:
|
||||
async for chunk in async_gen:
|
||||
if hasattr(chunk, "model_dump_json"):
|
||||
json_line = chunk.model_dump_json()
|
||||
else:
|
||||
json_line = json.dumps(chunk)
|
||||
yield json_line.encode("utf-8") + b"\n"
|
||||
else:
|
||||
json_line = (
|
||||
async_gen.model_dump_json()
|
||||
if hasattr(async_gen, "model_dump_json")
|
||||
else json.dumps(async_gen)
|
||||
)
|
||||
yield json_line.encode("utf-8") + b"\n"
|
||||
|
||||
finally:
|
||||
# Ensure counter is decremented even if an exception occurs
|
||||
await decrement_usage(endpoint, model)
|
||||
|
||||
# 5. Return a StreamingResponse backed by the generator
|
||||
return StreamingResponse(
|
||||
stream_generate_response(),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 7. API route – Chat
|
||||
# -------------------------------------------------------------
|
||||
@app.post("/api/chat")
|
||||
async def chat_proxy(request: Request):
|
||||
"""
|
||||
Proxy a chat request to Ollama and stream the endpoint reply.
|
||||
"""
|
||||
# 1. Parse and validate request
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
payload = json.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
model = payload.get("model")
|
||||
messages = payload.get("messages")
|
||||
tools = payload.get("tools")
|
||||
stream = payload.get("stream")
|
||||
think = payload.get("think")
|
||||
format = payload.get("format")
|
||||
options = payload.get("options")
|
||||
keep_alive = payload.get("keep_alive")
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
if not isinstance(messages, list):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing or invalid 'message' field (must be a list)"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
await increment_usage(endpoint, model)
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
|
||||
# 3. Async generator that streams chat data and decrements the counter
|
||||
async def stream_chat_response():
|
||||
try:
|
||||
# The chat method returns a generator of dicts (or GenerateResponse)
|
||||
async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
|
||||
if stream == True:
|
||||
async for chunk in async_gen:
|
||||
# `chunk` can be a dict or a pydantic model – dump to JSON safely
|
||||
if hasattr(chunk, "model_dump_json"):
|
||||
json_line = chunk.model_dump_json()
|
||||
else:
|
||||
json_line = json.dumps(chunk)
|
||||
yield json_line.encode("utf-8") + b"\n"
|
||||
else:
|
||||
json_line = (
|
||||
async_gen.model_dump_json()
|
||||
if hasattr(async_gen, "model_dump_json")
|
||||
else json.dumps(async_gen)
|
||||
)
|
||||
yield json_line.encode("utf-8") + b"\n"
|
||||
|
||||
finally:
|
||||
# Ensure counter is decremented even if an exception occurs
|
||||
await decrement_usage(endpoint, model)
|
||||
|
||||
# 4. Return a StreamingResponse backed by the generator
|
||||
return StreamingResponse(
|
||||
stream_chat_response(),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 8. API route – Embedding
|
||||
# -------------------------------------------------------------
|
||||
@app.post("/api/embeddings")
|
||||
async def embedding_proxy(request: Request):
|
||||
"""
|
||||
Proxy an embedding request to Ollama and reply with embeddings.
|
||||
|
||||
"""
|
||||
# 1. Parse and validate request
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
payload = json.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
model = payload.get("model")
|
||||
prompt = payload.get("prompt")
|
||||
options = payload.get("options")
|
||||
keep_alive = payload.get("keep_alive")
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
if not prompt:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'prompt'"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
await increment_usage(endpoint, model)
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
|
||||
# 3. Async generator that streams embedding data and decrements the counter
|
||||
async def stream_embedding_response():
|
||||
try:
|
||||
# The chat method returns a generator of dicts (or GenerateResponse)
|
||||
async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
|
||||
if hasattr(async_gen, "model_dump_json"):
|
||||
json_line = async_gen.model_dump_json()
|
||||
else:
|
||||
json_line = json.dumps(async_gen)
|
||||
yield json_line.encode("utf-8") + b"\n"
|
||||
finally:
|
||||
# Ensure counter is decremented even if an exception occurs
|
||||
await decrement_usage(endpoint, model)
|
||||
|
||||
# 5. Return a StreamingResponse backed by the generator
|
||||
return StreamingResponse(
|
||||
stream_embedding_response(),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 8. API route – Embed
|
||||
# -------------------------------------------------------------
|
||||
@app.post("/api/embed")
|
||||
async def embed_proxy(request: Request):
|
||||
"""
|
||||
Proxy an embed request to Ollama and reply with embeddings.
|
||||
|
||||
"""
|
||||
# 1. Parse and validate request
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
payload = json.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
model = payload.get("model")
|
||||
input = payload.get("input")
|
||||
truncate = payload.get("truncate")
|
||||
options = payload.get("options")
|
||||
keep_alive = payload.get("keep_alive")
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
if not input:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'input'"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
await increment_usage(endpoint, model)
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
|
||||
# 3. Async generator that streams embed data and decrements the counter
|
||||
async def stream_embedding_response():
|
||||
try:
|
||||
# The chat method returns a generator of dicts (or GenerateResponse)
|
||||
async_gen = await client.embed(model=model, input=input, truncate=truncate, options=options, keep_alive=keep_alive)
|
||||
if hasattr(async_gen, "model_dump_json"):
|
||||
json_line = async_gen.model_dump_json()
|
||||
else:
|
||||
json_line = json.dumps(async_gen)
|
||||
yield json_line.encode("utf-8") + b"\n"
|
||||
finally:
|
||||
# Ensure counter is decremented even if an exception occurs
|
||||
await decrement_usage(endpoint, model)
|
||||
|
||||
# 4. Return a StreamingResponse backed by the generator
|
||||
return StreamingResponse(
|
||||
stream_embedding_response(),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 9. API route – Create
|
||||
# -------------------------------------------------------------
|
||||
@app.post("/api/create")
|
||||
async def create_proxy(request: Request):
|
||||
"""
|
||||
Proxy a create request to all Ollama endpoints and reply with deduplicated status.
|
||||
"""
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
payload = json.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
model = payload.get("model")
|
||||
quantize = payload.get("quantize")
|
||||
from_ = payload.get("from")
|
||||
files = payload.get("files")
|
||||
adapters = payload.get("adapters")
|
||||
template = payload.get("template")
|
||||
license = payload.get("license")
|
||||
system = payload.get("system")
|
||||
parameters = payload.get("parameters")
|
||||
messages = payload.get("messages")
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
if not from_ and not files:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="You need to provide either from_ or files parameter!"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
status_lists = []
|
||||
for endpoint in config.endpoints:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
create = await client.create(model=model, quantize=quantize, from_=from_, files=files, adapters=adapters, template=template, license=license, system=system, parameters=parameters, messages=messages, stream=False)
|
||||
status_lists.append(create)
|
||||
|
||||
combined_status = []
|
||||
for status_list in status_lists:
|
||||
combined_status += status_list
|
||||
|
||||
final_status = list(dict.fromkeys(combined_status))
|
||||
|
||||
return dict(final_status)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 10. API route – Show
|
||||
# -------------------------------------------------------------
|
||||
@app.post("/api/show")
|
||||
async def show_proxy(request: Request):
|
||||
"""
|
||||
Proxy a model show request to Ollama and reply with ShowResponse.
|
||||
|
||||
"""
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
payload = json.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
model = payload.get("model")
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
await increment_usage(endpoint, model)
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
|
||||
# 3. Proxy a simple show request
|
||||
show = await client.show(model=model)
|
||||
|
||||
# 4. Return ShowResponse
|
||||
return show
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 11. API route – Copy
|
||||
# -------------------------------------------------------------
|
||||
@app.post("/api/copy")
|
||||
async def copy_proxy(request: Request):
|
||||
"""
|
||||
Proxy a model copy request to each Ollama endpoint and reply with Status Code.
|
||||
|
||||
"""
|
||||
# 1. Parse and validate request
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
payload = json.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
src = payload.get("source")
|
||||
dst = payload.get("destination")
|
||||
|
||||
if not src:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'source'"
|
||||
)
|
||||
if not dst:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'destination'"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 3. Iterate over all endpoints to copy the model on each endpoint
|
||||
status_list = []
|
||||
for endpoint in config.endpoints:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
# 4. Proxy a simple copy request
|
||||
copy = await client.copy(source=src, destination=dst)
|
||||
status_list.append(copy.status)
|
||||
|
||||
# 4. Return with 200 OK if all went well, 404 if a single endpoint failed
|
||||
if 404 in status_list:
|
||||
return Response(
|
||||
status_code=404
|
||||
)
|
||||
else:
|
||||
return Response(
|
||||
status_code=200
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 12. API route – Delete
|
||||
# -------------------------------------------------------------
|
||||
@app.delete("/api/delete")
|
||||
async def delete_proxy(request: Request):
|
||||
"""
|
||||
Proxy a model delete request to each Ollama endpoint and reply with Status Code.
|
||||
|
||||
"""
|
||||
# 1. Parse and validate request
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
payload = json.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
model = payload.get("model")
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Iterate over all endpoints to delete the model on each endpoint
|
||||
status_list = []
|
||||
for endpoint in config.endpoints:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
# 3. Proxy a simple copy request
|
||||
copy = await client.delete(model=model)
|
||||
status_list.append(copy.status)
|
||||
|
||||
# 4. Retrun 200 0K, if a single enpoint fails, respond with 404
|
||||
if 404 in status_list:
|
||||
return Response(
|
||||
status_code=404
|
||||
)
|
||||
else:
|
||||
return Response(
|
||||
status_code=200
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 13. API route – Pull
|
||||
# -------------------------------------------------------------
|
||||
@app.post("/api/pull")
|
||||
async def pull_proxy(request: Request):
|
||||
"""
|
||||
Proxy a pull request to all Ollama endpoint and report status back.
|
||||
"""
|
||||
# 1. Parse and validate request
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
payload = json.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
model = payload.get("model")
|
||||
insecure = payload.get("insecure")
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Iterate over all endpoints to pull the model
|
||||
status_list = []
|
||||
for endpoint in config.endpoints:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
# 3. Proxy a simple pull request
|
||||
pull = await client.pull(model=model, insecure=insecure, stream=False)
|
||||
status_list.append(pull)
|
||||
|
||||
combined_status = []
|
||||
for status in status_list:
|
||||
combined_status += status
|
||||
|
||||
# 4. Report back a deduplicated status message
|
||||
final_status = list(dict.fromkeys(combined_status))
|
||||
|
||||
return dict(final_status)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 14. API route – Push
|
||||
# -------------------------------------------------------------
|
||||
@app.post("/api/push")
|
||||
async def push_proxy(request: Request):
|
||||
"""
|
||||
Proxy a push request to Ollama and respond the deduplicated Ollama endpoint replies.
|
||||
"""
|
||||
# 1. Parse and validate request
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
payload = json.loads(body_bytes.decode("utf-8"))
|
||||
|
||||
model = payload.get("model")
|
||||
insecure = payload.get("insecure")
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing required field 'model'"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Iterate over all endpoints
|
||||
status_list = []
|
||||
for endpoint in config.endpoints:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
# 3. Proxy a simple push request
|
||||
push = await client.push(model=model, insecure=insecure, stream=False)
|
||||
status_list.append(push)
|
||||
|
||||
combined_status = []
|
||||
for status in status_list:
|
||||
combined_status += status
|
||||
|
||||
# 4. Report a deduplicated status
|
||||
final_status = list(dict.fromkeys(combined_status))
|
||||
|
||||
return dict(final_status)
|
||||
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 15. API route – Version
|
||||
# -------------------------------------------------------------
|
||||
@app.get("/api/version")
|
||||
async def version_proxy(request: Request):
|
||||
"""
|
||||
Proxy a version request to Ollama and reply lowest version of all endpoints.
|
||||
|
||||
"""
|
||||
# 1. Query all endpoints for version
|
||||
tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints]
|
||||
all_versions = await asyncio.gather(*tasks)
|
||||
|
||||
def version_key(v):
|
||||
return tuple(map(int, v.split('.')))
|
||||
|
||||
# 2. Return a JSONResponse with the min Version of all endpoints to maintain compatibility
|
||||
return JSONResponse(
|
||||
content={"version": str(min(all_versions, key=version_key))},
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 16. API route – tags
|
||||
# -------------------------------------------------------------
|
||||
@app.get("/api/tags")
|
||||
async def tags_proxy(request: Request):
|
||||
"""
|
||||
Proxy a tags request to Ollama endpoints and reply with a unique list of all models.
|
||||
|
||||
"""
|
||||
# 1. Query all endpoints for models
|
||||
tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints]
|
||||
all_models = await asyncio.gather(*tasks)
|
||||
|
||||
models = {'models': []}
|
||||
for modellist in all_models:
|
||||
models['models'] += modellist
|
||||
|
||||
# 2. Return a JSONResponse with a deduplicated list of unique models for inference
|
||||
return JSONResponse(
|
||||
content={"models": dedupe_on_keys(models['models'], ['digest','name'])},
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 17. API route – ps
|
||||
# -------------------------------------------------------------
|
||||
@app.get("/api/ps")
|
||||
async def ps_proxy(request: Request):
|
||||
"""
|
||||
Proxy a ps request to all Ollama endpoints and reply a unique list of all running models.
|
||||
|
||||
"""
|
||||
# 1. Query all endpoints for running models
|
||||
tasks = [fetch_endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints]
|
||||
loaded_models = await asyncio.gather(*tasks)
|
||||
|
||||
models = {'models': []}
|
||||
for modellist in loaded_models:
|
||||
models['models'] += modellist
|
||||
|
||||
# 25. Return a JSONResponse with deduplicated currently deployed models
|
||||
return JSONResponse(
|
||||
content={"models": dedupe_on_keys(models['models'], ['digest'])},
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 18. FastAPI startup event – load configuration
|
||||
# -------------------------------------------------------------
|
||||
@app.on_event("startup")
|
||||
async def startup_event() -> None:
|
||||
global config
|
||||
# Load YAML config (or use defaults if not present)
|
||||
config = Config.from_yaml(Path("config.yaml"))
|
||||
print(f"Loaded configuration:\n endpoints={config.endpoints},\n "
|
||||
f"max_concurrent_connections={config.max_concurrent_connections}")
|
||||
Loading…
Add table
Add a link
Reference in a new issue