"""
title: NOMYO Router - an Ollama Proxy with Endpoint:Model aware routing
author: alpha-nerd-nomyo
author_url: https://github.com/nomyo-ai
version: 0.1
license: AGPL
"""

# -------------------------------------------------------------
import json, random, asyncio, yaml, httpx, ollama, openai
from pathlib import Path
from typing import Dict, Set, List
from fastapi import FastAPI, Request, HTTPException
from starlette.responses import StreamingResponse, JSONResponse, Response
from pydantic import Field
from pydantic_settings import BaseSettings
from collections import defaultdict

# -------------------------------------------------------------
# 1. Configuration loader
# -------------------------------------------------------------
class Config(BaseSettings):
    # List of Ollama endpoints
    endpoints: list[str] = Field(
        default_factory=lambda: [
            "http://localhost:11434",
        ]
    )
    # Max concurrent connections per endpoint‑model pair, see OLLAMA_NUM_PARALLEL
    max_concurrent_connections: int = 1

    class Config:
        # Load from `config.yaml` first, then from env variables
        env_prefix = "OLLAMA_PROXY_"
        yaml_file = Path("config.yaml")  # relative to cwd

    @classmethod
    def from_yaml(cls, path: Path) -> "Config":
        """Load the YAML file and create the Config instance."""
        if path.exists():
            with path.open("r", encoding="utf-8") as fp:
                data = yaml.safe_load(fp) or {}
            return cls(**data)
        return cls()


# Create the global config object – it will be overwritten on startup
config = Config()

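# Example `config.yaml` consumed by Config.from_yaml() above (a minimal sketch;
# the second endpoint URL and the limit of 2 are placeholders, not shipped defaults):
#
#   endpoints:
#     - "http://localhost:11434"
#     - "http://gpu-box:11434"
#   max_concurrent_connections: 2
#
# The same settings can also be supplied through environment variables using the
# OLLAMA_PROXY_ prefix wired up via env_prefix above.
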
# -------------------------------------------------------------
# 2. FastAPI application
# -------------------------------------------------------------
app = FastAPI()

# -------------------------------------------------------------
# 3. Global state: per‑endpoint per‑model active connection counters
# -------------------------------------------------------------
usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
usage_lock = asyncio.Lock()  # protects access to usage_counts

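# At runtime usage_counts maps endpoint URL -> model name -> number of in-flight
# requests, e.g. (illustrative values only, the model name is a placeholder):
#   {"http://localhost:11434": {"llama3:8b": 1}}
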
# -------------------------------------------------------------
# 4. Helper functions
# -------------------------------------------------------------
async def fetch_loaded_models(endpoint: str) -> Set[str]:
    """
    Query <endpoint>/api/ps and return a set of model names that are currently
    loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
    set is returned.
    """
    try:
        async with httpx.AsyncClient(timeout=1.0) as client:
            resp = await client.get(f"{endpoint}/api/ps")
            resp.raise_for_status()
            data = resp.json()
            # The response format is:
            # {"models": [{"name": "model1"}, {"name": "model2"}]}
            models = {m.get("name") for m in data.get("models", []) if m.get("name")}
            return models
    except Exception:
        # If anything goes wrong we simply assume the endpoint has no models
        return set()

async def fetch_endpoint_details(endpoint: str, route: str, detail: str) -> List[dict]:
    """
    Query <endpoint>/<route>, extract <detail> from the JSON reply and return it
    for the corresponding Ollama endpoint. If the request fails, "N/A" is returned
    for the requested detail.
    """
    try:
        async with httpx.AsyncClient(timeout=1.0) as client:
            resp = await client.get(f"{endpoint}{route}")
            resp.raise_for_status()
            data = resp.json()
            return data.get(detail)
    except Exception:
        # If anything goes wrong we cannot report the requested detail
        return {detail: "N/A"}

def dedupe_on_keys(dicts, key_fields):
    """
    Helper function to deduplicate endpoint details based on given dict keys.
    """
    seen = set()
    out = []
    for d in dicts:
        # Build a tuple of the values for the chosen keys
        key = tuple(d.get(k) for k in key_fields)
        if key not in seen:
            seen.add(key)
            out.append(d)
    return out

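# A quick sketch of how dedupe_on_keys behaves (hypothetical input data):
#
#   >>> dedupe_on_keys(
#   ...     [{"name": "m1", "digest": "a"}, {"name": "m1", "digest": "a"}, {"name": "m2", "digest": "b"}],
#   ...     ["digest", "name"],
#   ... )
#   [{'name': 'm1', 'digest': 'a'}, {'name': 'm2', 'digest': 'b'}]
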
# -------------------------------------------------------------
# 5. Endpoint selection logic (respecting the configurable limit)
# -------------------------------------------------------------
async def choose_endpoint(model: str) -> str:
    """
    Determine which endpoint to use for the given model while respecting the
    `max_concurrent_connections` per endpoint‑model pair.

    The selection algorithm is as follows:

    1. Find all endpoints that have the model *already loaded* and that still
       have a free slot (< max_concurrent_connections).
    2. If none exist, find any endpoint that has a free slot regardless of
       whether the model is loaded – the endpoint will load the model on demand.
    3. If all endpoints are at capacity for this model, pick any endpoint
       arbitrarily – the request will queue on that endpoint.
    """
    # Gather loaded models for all endpoints concurrently
    tasks = [fetch_loaded_models(ep) for ep in config.endpoints]
    loaded_sets = await asyncio.gather(*tasks)

    async with usage_lock:
        # 1️⃣ Endpoints that have the model loaded *and* a free slot
        loaded_and_free = [
            ep for ep, models in zip(config.endpoints, loaded_sets)
            if model in models and usage_counts[ep].get(model, 0) < config.max_concurrent_connections
        ]

        if loaded_and_free:
            # Prefer an endpoint that already hosts the model and has capacity
            return random.choice(loaded_and_free)

        # 2️⃣ Endpoints that simply have a free slot (model may or may not be loaded)
        endpoints_with_free_slot = [
            ep for ep in config.endpoints
            if usage_counts[ep].get(model, 0) < config.max_concurrent_connections
        ]

        if endpoints_with_free_slot:
            return random.choice(endpoints_with_free_slot)

        # 3️⃣ All endpoints are at capacity – pick any (will queue on that endpoint according to OLLAMA_MAX_QUEUE)
        return random.choice(config.endpoints)

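# Example of the selection order (hypothetical state, two endpoints, limit 1):
#   - ep1 has the model loaded with 0 active requests          -> rule 1 picks ep1
#   - ep1 is busy (1 active) but ep2 still has a free slot     -> rule 2 picks ep2
#   - both endpoints are at capacity for this model            -> rule 3 picks either at random
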
async def increment_usage(endpoint: str, model: str) -> None:
    async with usage_lock:
        usage_counts[endpoint][model] += 1


async def decrement_usage(endpoint: str, model: str) -> None:
    async with usage_lock:
        # Avoid negative counts
        current = usage_counts[endpoint].get(model, 0)
        if current > 0:
            usage_counts[endpoint][model] = current - 1
        # Optionally, clean up zero entries
        if usage_counts[endpoint].get(model, 0) == 0:
            usage_counts[endpoint].pop(model, None)
            if not usage_counts[endpoint]:
                usage_counts.pop(endpoint, None)

# -------------------------------------------------------------
# 6. API route – Generate
# -------------------------------------------------------------
@app.post("/api/generate")
async def proxy(request: Request):
    """
    Proxy a generate request to Ollama and stream the response back to the client.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        prompt = payload.get("prompt")
        suffix = payload.get("suffix")
        system = payload.get("system")
        template = payload.get("template")
        context = payload.get("context")
        stream = payload.get("stream")
        think = payload.get("think")
        raw = payload.get("raw")
        format = payload.get("format")
        images = payload.get("images")
        options = payload.get("options")
        keep_alive = payload.get("keep_alive")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not prompt:
            raise HTTPException(
                status_code=400, detail="Missing required field 'prompt'"
            )
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Decide which endpoint to use
    endpoint = await choose_endpoint(model)

    # Increment usage counter for this endpoint‑model pair
    await increment_usage(endpoint, model)

    # 3. Create Ollama client instance
    client = ollama.AsyncClient(host=endpoint)

    # 4. Async generator that streams data and decrements the counter
    async def stream_generate_response():
        try:
            async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=format, images=images, options=options, keep_alive=keep_alive)
            if stream:
                async for chunk in async_gen:
                    if hasattr(chunk, "model_dump_json"):
                        json_line = chunk.model_dump_json()
                    else:
                        json_line = json.dumps(chunk)
                    yield json_line.encode("utf-8") + b"\n"
            else:
                json_line = (
                    async_gen.model_dump_json()
                    if hasattr(async_gen, "model_dump_json")
                    else json.dumps(async_gen)
                )
                yield json_line.encode("utf-8") + b"\n"

        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, model)

    # 5. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_generate_response(),
        media_type="application/json",
    )

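# Example call against the proxy (a sketch; it assumes the proxy listens on
# localhost:8000 and the model name is a placeholder). The reply is streamed as
# newline-delimited JSON, just like a direct Ollama /api/generate call:
#
#   curl http://localhost:8000/api/generate -d '{
#     "model": "llama3",
#     "prompt": "Why is the sky blue?",
#     "stream": true
#   }'
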
# -------------------------------------------------------------
# 7. API route – Chat
# -------------------------------------------------------------
@app.post("/api/chat")
async def chat_proxy(request: Request):
    """
    Proxy a chat request to Ollama and stream the endpoint reply.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        messages = payload.get("messages")
        tools = payload.get("tools")
        stream = payload.get("stream")
        think = payload.get("think")
        format = payload.get("format")
        options = payload.get("options")
        keep_alive = payload.get("keep_alive")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not isinstance(messages, list):
            raise HTTPException(
                status_code=400, detail="Missing or invalid 'messages' field (must be a list)"
            )
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Endpoint logic
    endpoint = await choose_endpoint(model)
    await increment_usage(endpoint, model)
    client = ollama.AsyncClient(host=endpoint)

    # 3. Async generator that streams chat data and decrements the counter
    async def stream_chat_response():
        try:
            # The chat method returns an async generator of chunks (or a single response object)
            async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
            if stream:
                async for chunk in async_gen:
                    # `chunk` can be a dict or a pydantic model – dump to JSON safely
                    if hasattr(chunk, "model_dump_json"):
                        json_line = chunk.model_dump_json()
                    else:
                        json_line = json.dumps(chunk)
                    yield json_line.encode("utf-8") + b"\n"
            else:
                json_line = (
                    async_gen.model_dump_json()
                    if hasattr(async_gen, "model_dump_json")
                    else json.dumps(async_gen)
                )
                yield json_line.encode("utf-8") + b"\n"

        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, model)

    # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_chat_response(),
        media_type="application/json",
    )

# -------------------------------------------------------------
# 8. API route – Embedding - deprecated
# -------------------------------------------------------------
@app.post("/api/embeddings")
async def embedding_proxy(request: Request):
    """
    Proxy an embedding request to Ollama and reply with embeddings.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        prompt = payload.get("prompt")
        options = payload.get("options")
        keep_alive = payload.get("keep_alive")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not prompt:
            raise HTTPException(
                status_code=400, detail="Missing required field 'prompt'"
            )
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Endpoint logic
    endpoint = await choose_endpoint(model)
    await increment_usage(endpoint, model)
    client = ollama.AsyncClient(host=endpoint)

    # 3. Async generator that yields the embedding response and decrements the counter
    async def stream_embedding_response():
        try:
            # The embeddings method returns a single response object (no streaming)
            response = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
            if hasattr(response, "model_dump_json"):
                json_line = response.model_dump_json()
            else:
                json_line = json.dumps(response)
            yield json_line.encode("utf-8") + b"\n"
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, model)

    # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_embedding_response(),
        media_type="application/json",
    )

# -------------------------------------------------------------
# 8. API route – Embed
# -------------------------------------------------------------
@app.post("/api/embed")
async def embed_proxy(request: Request):
    """
    Proxy an embed request to Ollama and reply with embeddings.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        input = payload.get("input")
        truncate = payload.get("truncate")
        options = payload.get("options")
        keep_alive = payload.get("keep_alive")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not input:
            raise HTTPException(
                status_code=400, detail="Missing required field 'input'"
            )
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Endpoint logic
    endpoint = await choose_endpoint(model)
    await increment_usage(endpoint, model)
    client = ollama.AsyncClient(host=endpoint)

    # 3. Async generator that yields the embed response and decrements the counter
    async def stream_embedding_response():
        try:
            # The embed method returns a single response object (no streaming)
            response = await client.embed(model=model, input=input, truncate=truncate, options=options, keep_alive=keep_alive)
            if hasattr(response, "model_dump_json"):
                json_line = response.model_dump_json()
            else:
                json_line = json.dumps(response)
            yield json_line.encode("utf-8") + b"\n"
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, model)

    # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_embedding_response(),
        media_type="application/json",
    )

# -------------------------------------------------------------
# 9. API route – Create
# -------------------------------------------------------------
@app.post("/api/create")
async def create_proxy(request: Request):
    """
    Proxy a create request to all Ollama endpoints and reply with a deduplicated status.
    """
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        quantize = payload.get("quantize")
        from_ = payload.get("from")
        files = payload.get("files")
        adapters = payload.get("adapters")
        template = payload.get("template")
        license = payload.get("license")
        system = payload.get("system")
        parameters = payload.get("parameters")
        messages = payload.get("messages")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not from_ and not files:
            raise HTTPException(
                status_code=400, detail="You need to provide either the 'from' or the 'files' parameter!"
            )
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    status_lists = []
    for endpoint in config.endpoints:
        client = ollama.AsyncClient(host=endpoint)
        create = await client.create(model=model, quantize=quantize, from_=from_, files=files, adapters=adapters, template=template, license=license, system=system, parameters=parameters, messages=messages, stream=False)
        status_lists.append(create)

    combined_status = []
    for status_list in status_lists:
        combined_status += status_list

    final_status = list(dict.fromkeys(combined_status))

    return dict(final_status)

# -------------------------------------------------------------
# 10. API route – Show
# -------------------------------------------------------------
@app.post("/api/show")
async def show_proxy(request: Request):
    """
    Proxy a model show request to Ollama and reply with the ShowResponse.
    """
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Endpoint logic
    endpoint = await choose_endpoint(model)
    await increment_usage(endpoint, model)
    client = ollama.AsyncClient(host=endpoint)

    # 3. Proxy a simple show request
    try:
        show = await client.show(model=model)
    finally:
        # Release the usage slot again, even if the show call fails
        await decrement_usage(endpoint, model)

    # 4. Return ShowResponse
    return show

# -------------------------------------------------------------
# 11. API route – Copy
# -------------------------------------------------------------
@app.post("/api/copy")
async def copy_proxy(request: Request):
    """
    Proxy a model copy request to each Ollama endpoint and reply with a status code.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))

        src = payload.get("source")
        dst = payload.get("destination")

        if not src:
            raise HTTPException(
                status_code=400, detail="Missing required field 'source'"
            )
        if not dst:
            raise HTTPException(
                status_code=400, detail="Missing required field 'destination'"
            )
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Iterate over all endpoints to copy the model on each endpoint
    status_list = []
    for endpoint in config.endpoints:
        client = ollama.AsyncClient(host=endpoint)
        # 3. Proxy a simple copy request
        copy = await client.copy(source=src, destination=dst)
        status_list.append(copy.status)

    # 4. Return 200 OK if all went well, 404 if a single endpoint failed
    if 404 in status_list:
        return Response(
            status_code=404
        )
    else:
        return Response(
            status_code=200
        )

# -------------------------------------------------------------
# 12. API route – Delete
# -------------------------------------------------------------
@app.delete("/api/delete")
async def delete_proxy(request: Request):
    """
    Proxy a model delete request to each Ollama endpoint and reply with a status code.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Iterate over all endpoints to delete the model on each endpoint
    status_list = []
    for endpoint in config.endpoints:
        client = ollama.AsyncClient(host=endpoint)
        # 3. Proxy a simple delete request
        delete = await client.delete(model=model)
        status_list.append(delete.status)

    # 4. Return 200 OK; if a single endpoint fails, respond with 404
    if 404 in status_list:
        return Response(
            status_code=404
        )
    else:
        return Response(
            status_code=200
        )

# -------------------------------------------------------------
# 13. API route – Pull
# -------------------------------------------------------------
@app.post("/api/pull")
async def pull_proxy(request: Request):
    """
    Proxy a pull request to all Ollama endpoints and report the status back.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        insecure = payload.get("insecure")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Iterate over all endpoints to pull the model
    status_list = []
    for endpoint in config.endpoints:
        client = ollama.AsyncClient(host=endpoint)
        # 3. Proxy a simple pull request
        pull = await client.pull(model=model, insecure=insecure, stream=False)
        status_list.append(pull)

    combined_status = []
    for status in status_list:
        combined_status += status

    # 4. Report back a deduplicated status message
    final_status = list(dict.fromkeys(combined_status))

    return dict(final_status)

# -------------------------------------------------------------
# 14. API route – Push
# -------------------------------------------------------------
@app.post("/api/push")
async def push_proxy(request: Request):
    """
    Proxy a push request to Ollama and respond with the deduplicated endpoint replies.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        insecure = payload.get("insecure")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Iterate over all endpoints
    status_list = []
    for endpoint in config.endpoints:
        client = ollama.AsyncClient(host=endpoint)
        # 3. Proxy a simple push request
        push = await client.push(model=model, insecure=insecure, stream=False)
        status_list.append(push)

    combined_status = []
    for status in status_list:
        combined_status += status

    # 4. Report a deduplicated status
    final_status = list(dict.fromkeys(combined_status))

    return dict(final_status)

# -------------------------------------------------------------
# 15. API route – Version
# -------------------------------------------------------------
@app.get("/api/version")
async def version_proxy(request: Request):
    """
    Proxy a version request to Ollama and reply with the lowest version among all endpoints.
    """
    # 1. Query all endpoints for their version
    tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints]
    all_versions = await asyncio.gather(*tasks)

    def version_key(v):
        return tuple(map(int, v.split('.')))

    # 2. Return a JSONResponse with the min version of all endpoints to maintain compatibility
    return JSONResponse(
        content={"version": str(min(all_versions, key=version_key))},
        status_code=200,
    )

# -------------------------------------------------------------
# 16. API route – tags
# -------------------------------------------------------------
@app.get("/api/tags")
async def tags_proxy(request: Request):
    """
    Proxy a tags request to Ollama endpoints and reply with a unique list of all models.
    """
    # 1. Query all endpoints for models
    tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints]
    all_models = await asyncio.gather(*tasks)

    models = {'models': []}
    for modellist in all_models:
        models['models'] += modellist

    # 2. Return a JSONResponse with a deduplicated list of unique models for inference
    return JSONResponse(
        content={"models": dedupe_on_keys(models['models'], ['digest', 'name'])},
        status_code=200,
    )

# -------------------------------------------------------------
# 17. API route – ps
# -------------------------------------------------------------
@app.get("/api/ps")
async def ps_proxy(request: Request):
    """
    Proxy a ps request to all Ollama endpoints and reply with a unique list of all running models.
    """
    # 1. Query all endpoints for running models
    tasks = [fetch_endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints]
    loaded_models = await asyncio.gather(*tasks)

    models = {'models': []}
    for modellist in loaded_models:
        models['models'] += modellist

    # 2. Return a JSONResponse with the deduplicated currently deployed models
    return JSONResponse(
        content={"models": dedupe_on_keys(models['models'], ['digest'])},
        status_code=200,
    )

# -------------------------------------------------------------
# 18. API route – OpenAI compatible Embedding
# -------------------------------------------------------------
@app.post("/v1/embeddings")
async def openai_embedding_proxy(request: Request):
    """
    Proxy an OpenAI API compatible embedding request to Ollama and reply with embeddings.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        input = payload.get("input")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not input:
            raise HTTPException(
                status_code=400, detail="Missing required field 'input'"
            )
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Endpoint logic
    endpoint = await choose_endpoint(model)
    await increment_usage(endpoint, model)
    oclient = openai.AsyncOpenAI(base_url=endpoint + "/v1", api_key="ollama")

    # 3. Forward the embedding request and make sure the counter is decremented again
    try:
        response = await oclient.embeddings.create(input=[input], model=model)
    finally:
        await decrement_usage(endpoint, model)

    # 4. Return the embedding response
    return response

# -------------------------------------------------------------
# 19. API route – OpenAI compatible Chat Completions
# -------------------------------------------------------------
@app.post("/v1/chat/completions")
async def openai_chat_completions_proxy(request: Request):
    """
    Proxy an OpenAI API compatible chat completions request to Ollama and reply with a streaming response.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        messages = payload.get("messages")
        frequency_penalty = payload.get("frequency_penalty")
        presence_penalty = payload.get("presence_penalty")
        response_format = payload.get("response_format")
        seed = payload.get("seed")
        stop = payload.get("stop")
        stream = payload.get("stream")
        stream_options = payload.get("stream_options")
        temperature = payload.get("temperature")
        top_p = payload.get("top_p")
        max_tokens = payload.get("max_tokens")
        tools = payload.get("tools")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not isinstance(messages, list):
            raise HTTPException(
                status_code=400, detail="Missing required field 'messages' (must be a list)"
            )
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Endpoint logic
    endpoint = await choose_endpoint(model)
    await increment_usage(endpoint, model)
    oclient = openai.AsyncOpenAI(base_url=endpoint + "/v1", api_key="ollama")

    # 3. Async generator that streams completions data and decrements the counter
    async def stream_ochat_response():
        try:
            # With stream=True the create call returns an async stream of chunks,
            # otherwise a single chat completion object
            async_gen = await oclient.chat.completions.create(messages=messages, model=model, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, response_format=response_format, seed=seed, stop=stop, stream=stream, stream_options=stream_options, temperature=temperature, top_p=top_p, max_tokens=max_tokens, tools=tools)
            if stream:
                async for chunk in async_gen:
                    data = (
                        chunk.model_dump_json()
                        if hasattr(chunk, "model_dump_json")
                        else json.dumps(chunk)
                    )
                    yield f"data: {data}\n\n".encode("utf-8")
                # Final DONE event
                yield b"data: [DONE]\n\n"
            else:
                json_line = (
                    async_gen.model_dump_json()
                    if hasattr(async_gen, "model_dump_json")
                    else json.dumps(async_gen)
                )
                yield json_line.encode("utf-8") + b"\n"

        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, model)

    # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_ochat_response(),
        media_type="application/json",
    )

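# Example usage of the OpenAI compatible route through this proxy (a sketch;
# it assumes the proxy runs on localhost:8000 and the model name is a placeholder):
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="ollama")
#   resp = client.chat.completions.create(
#       model="llama3",
#       messages=[{"role": "user", "content": "Hello!"}],
#   )
#   print(resp.choices[0].message.content)
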
# -------------------------------------------------------------
# 20. API route – OpenAI compatible Completions
# -------------------------------------------------------------
@app.post("/v1/completions")
async def openai_completions_proxy(request: Request):
    """
    Proxy an OpenAI API compatible completions request to Ollama and reply with a streaming response.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        prompt = payload.get("prompt")
        frequency_penalty = payload.get("frequency_penalty")
        presence_penalty = payload.get("presence_penalty")
        seed = payload.get("seed")
        stop = payload.get("stop")
        stream = payload.get("stream")
        stream_options = payload.get("stream_options")
        temperature = payload.get("temperature")
        top_p = payload.get("top_p")
        max_tokens = payload.get("max_tokens")
        suffix = payload.get("suffix")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not prompt:
            raise HTTPException(
                status_code=400, detail="Missing required field 'prompt'"
            )
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Endpoint logic
    endpoint = await choose_endpoint(model)
    await increment_usage(endpoint, model)
    oclient = openai.AsyncOpenAI(base_url=endpoint + "/v1", api_key="ollama")

    # 3. Async generator that streams completions data and decrements the counter
    async def stream_ocompletions_response():
        try:
            # With stream=True the create call returns an async stream of chunks,
            # otherwise a single completion object
            async_gen = await oclient.completions.create(model=model, prompt=prompt, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, seed=seed, stop=stop, stream=stream, stream_options=stream_options, temperature=temperature, top_p=top_p, max_tokens=max_tokens, suffix=suffix)
            if stream:
                async for chunk in async_gen:
                    data = (
                        chunk.model_dump_json()
                        if hasattr(chunk, "model_dump_json")
                        else json.dumps(chunk)
                    )
                    yield f"data: {data}\n\n".encode("utf-8")
                # Final DONE event
                yield b"data: [DONE]\n\n"
            else:
                json_line = (
                    async_gen.model_dump_json()
                    if hasattr(async_gen, "model_dump_json")
                    else json.dumps(async_gen)
                )
                yield json_line.encode("utf-8") + b"\n"

        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, model)

    # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_ocompletions_response(),
        media_type="application/json",
    )

# -------------------------------------------------------------
# 21. OpenAI API compatible endpoints #ToDo
# -------------------------------------------------------------
@app.post("/v1/models")
async def not_implemented_yet(request: Request):
    return Response(
        status_code=501
    )

# -------------------------------------------------------------
# 22. FastAPI startup event – load configuration
# -------------------------------------------------------------
@app.on_event("startup")
async def startup_event() -> None:
    global config
    # Load YAML config (or use defaults if not present)
    config = Config.from_yaml(Path("config.yaml"))
    print(f"Loaded configuration:\n endpoints={config.endpoints},\n "
          f"max_concurrent_connections={config.max_concurrent_connections}")