various performance improvements; replace the stdlib json module with orjson

Alpha Nerd 2025-11-10 15:37:46 +01:00
parent c6c1059ede
commit 1427e98e6d
2 changed files with 70 additions and 62 deletions

@ -1,6 +1,7 @@
aiohappyeyeballs==2.6.1
aiohttp==3.12.15
aiosignal==1.4.0
annotated-doc==0.0.3
annotated-types==0.7.0
anyio==4.10.0
async-timeout==5.0.1
@ -20,6 +21,7 @@ jiter==0.10.0
multidict==6.6.4
ollama==0.6.0
openai==1.102.0
orjson==3.11.4
pillow==11.3.0
propcache==0.3.2
pydantic==2.11.7

router.py

@ -6,7 +6,7 @@ version: 0.4
license: AGPL
"""
# -------------------------------------------------------------
import json, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random, base64, io
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random, base64, io
from pathlib import Path
from typing import Dict, Set, List, Optional
from urllib.parse import urlparse
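
Note that orjson is not a drop-in replacement for the stdlib json module, which is why several call sites below change shape as well as name: orjson.dumps() returns bytes rather than str, orjson.loads() accepts bytes or str directly, and keyword arguments such as sort_keys become an option bitmask. A minimal sketch of the differences:

    import json
    import orjson

    payload = {"b": 2, "a": 1}

    # stdlib: str in, str out, keyword arguments
    text = json.dumps(payload, sort_keys=True)                # '{"a": 1, "b": 2}'

    # orjson: compact bytes out, option bitmask instead of keywords
    raw = orjson.dumps(payload, option=orjson.OPT_SORT_KEYS)  # b'{"a":1,"b":2}'
    assert raw.decode("utf-8") == '{"a":1,"b":2}'

    # orjson.loads accepts bytes directly, so a prior .decode() is optional
    assert orjson.loads(raw) == json.loads(text) == payload
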
@ -30,10 +30,11 @@ _models_cache: dict[str, tuple[Set[str], float]] = {}
_error_cache: dict[str, float] = {}
# ------------------------------------------------------------------
# SSE Queues
# Queues
# ------------------------------------------------------------------
_subscribers: Set[asyncio.Queue] = set()
_subscribers_lock = asyncio.Lock()
token_queue: asyncio.Queue[tuple[str, str, int, int]] = asyncio.Queue()
# ------------------------------------------------------------------
# aiohttp Global Sessions
@ -125,6 +126,7 @@ default_headers={
usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
usage_lock = asyncio.Lock() # protects access to usage_counts
token_usage_lock = asyncio.Lock()
# -------------------------------------------------------------
# 4. Helper functions
@ -192,12 +194,12 @@ def is_ext_openai_endpoint(endpoint: str) -> bool:
return True # It's an external OpenAI endpoint
def record_token_usage(endpoint: str, model: str, prompt: int = 0, completion: int = 0) -> None:
    async def _record():
        async with usage_lock: # reuse the same lock that protects usage_counts
            token_usage_counts[endpoint][model] += (prompt + completion)
        await publish_snapshot() # immediately broadcast the new totals
    asyncio.create_task(_record())
async def token_worker() -> None:
    while True:
        endpoint, model, prompt, comp = await token_queue.get()
        async with token_usage_lock:
            token_usage_counts[endpoint][model] += (prompt + comp)
        await publish_snapshot()
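
The replaced record_token_usage() spawned a fresh asyncio task for every recorded chunk, each one contending for the shared lock and re-publishing a snapshot; token_worker() instead drains a single token_queue from one long-lived task. A self-contained sketch of the same producer/consumer pattern (names simplified, endpoint/model values invented for illustration):

    import asyncio
    from collections import defaultdict

    queue: asyncio.Queue[tuple[str, str, int, int]] = asyncio.Queue()
    counts: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
    lock = asyncio.Lock()

    async def worker() -> None:
        # single consumer: updates are serialized by one task instead of
        # one task (and one lock acquisition) per recorded chunk
        while True:
            endpoint, model, prompt, completion = await queue.get()
            async with lock:
                counts[endpoint][model] += prompt + completion
            queue.task_done()

    async def main() -> None:
        task = asyncio.create_task(worker())
        for _ in range(3):
            await queue.put(("http://localhost:11434", "llama3", 10, 5))
        await queue.join()  # block until the worker has drained the queue
        assert counts["http://localhost:11434"]["llama3"] == 45
        task.cancel()

    asyncio.run(main())
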
class fetch:
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
@ -267,6 +269,8 @@ class fetch:
set is returned.
"""
client: aiohttp.ClientSession = app_state["session"]
if is_ext_openai_endpoint(endpoint):
return set()
try:
async with client.get(f"{endpoint}/api/ps") as resp:
await _ensure_success(resp)
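
/api/ps is an Ollama-only route, so probing an external OpenAI-compatible endpoint with it can only fail; the added guard returns an empty set before any network round trip. A hypothetical condensed form of the same logic (running_models is an invented name):

    async def running_models(endpoint: str) -> set[str]:
        if is_ext_openai_endpoint(endpoint):
            return set()  # OpenAI-style servers do not expose /api/ps
        async with app_state["session"].get(f"{endpoint}/api/ps") as resp:
            resp.raise_for_status()
            data = await resp.json()
        return {m["name"] for m in data.get("models", [])}
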
@ -428,18 +432,19 @@ def transform_images_to_data_urls(message_list):
class rechunk:
def openai_chat_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.ChatResponse:
now = time.perf_counter()
if chunk.choices == [] and chunk.usage is not None:
return ollama.ChatResponse(
model=chunk.model,
created_at=iso8601_ns(),
done=True,
done_reason='stop',
total_duration=int((time.perf_counter() - start_ts) * 1_000_000_000),
total_duration=int((now - start_ts) * 1_000_000_000),
load_duration=100000,
prompt_eval_count=int(chunk.usage.prompt_tokens),
prompt_eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)),
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)),
eval_count=int(chunk.usage.completion_tokens),
eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000),
eval_duration=int((now - start_ts) * 1_000_000_000),
message={"role": "assistant"}
)
with_thinking = chunk.choices[0] if chunk.choices[0] else None
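
The timing change in this hunk is the hoisted now = time.perf_counter(): the old code re-read the clock for every duration field, so the fields were measured against slightly different instants. The pattern in isolation:

    import time

    start_ts = time.perf_counter()

    # before: each field re-reads the clock, so the values drift apart
    total_a = int((time.perf_counter() - start_ts) * 1_000_000_000)
    eval_a = int((time.perf_counter() - start_ts) * 1_000_000_000)

    # after: one reading per chunk, consistent fields, fewer clock calls
    now = time.perf_counter()
    total_b = int((now - start_ts) * 1_000_000_000)
    eval_b = int((now - start_ts) * 1_000_000_000)
    assert total_b == eval_b  # guaranteed; total_a == eval_a is not
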
@ -463,16 +468,17 @@ class rechunk:
created_at=iso8601_ns(),
done=True if chunk.usage is not None else False,
done_reason=chunk.choices[0].finish_reason, #if chunk.choices[0].finish_reason is not None else None,
total_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
load_duration=100000,
prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
prompt_eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
message=assistant_msg)
return rechunk
def openai_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.GenerateResponse:
now = time.perf_counter()
with_thinking = chunk.choices[0] if chunk.choices[0] else None
thinking = getattr(with_thinking, "reasoning", None) if with_thinking else None
rechunk = ollama.GenerateResponse(
@ -480,12 +486,12 @@ class rechunk:
created_at=iso8601_ns(),
done=True if chunk.usage is not None else False,
done_reason=chunk.choices[0].finish_reason,
total_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
load_duration=10000,
prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
prompt_eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
response=chunk.choices[0].text or '',
thinking=thinking)
return rechunk
@ -514,9 +520,9 @@ class rechunk:
# ------------------------------------------------------------------
async def publish_snapshot():
async with usage_lock:
snapshot = json.dumps({"usage_counts": usage_counts,
snapshot = orjson.dumps({"usage_counts": usage_counts,
"token_usage_counts": token_usage_counts,
}, sort_keys=True)
}, option=orjson.OPT_SORT_KEYS).decode("utf-8")
async with _subscribers_lock:
for q in _subscribers:
# If the queue is full, drop the message to avoid backpressure.
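
The remainder of publish_snapshot() falls outside this hunk; assuming the loop body matches the comment above, the non-blocking fan-out looks like this sketch:

    # put_nowait never suspends: a slow consumer loses one snapshot
    # instead of stalling the publisher and every other subscriber
    for q in _subscribers:
        try:
            q.put_nowait(snapshot)
        except asyncio.QueueFull:
            pass  # drop it; the next snapshot carries fresher totals
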
@ -650,7 +656,7 @@ async def proxy(request: Request):
"""
try:
body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8"))
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
prompt = payload.get("prompt")
@ -674,7 +680,7 @@ async def proxy(request: Request):
raise HTTPException(
status_code=400, detail="Missing required field 'prompt'"
)
except json.JSONDecodeError as e:
except orjson.JSONDecodeError as e:
error_msg = f"Invalid JSON format in request body: {str(e)}. Please ensure the request is properly formatted."
raise HTTPException(status_code=400, detail=error_msg) from e
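
orjson.JSONDecodeError is a subclass of both ValueError and json.JSONDecodeError, so these except clauses behave as before. Note also that orjson.loads() accepts the raw bytes, in which case a malformed or non-UTF-8 body surfaces as JSONDecodeError rather than as an uncaught UnicodeDecodeError from .decode(). The parse-and-validate shape used by these handlers, reduced to its core (a sketch, not the file's code):

    import orjson
    from fastapi import HTTPException, Request

    async def parse_body(request: Request) -> dict:
        body_bytes = await request.body()
        try:
            payload = orjson.loads(body_bytes)  # bytes in, no .decode() needed
        except orjson.JSONDecodeError as e:
            raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
        if "model" not in payload:
            raise HTTPException(status_code=400, detail="Missing required field 'model'")
        return payload
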
@ -721,11 +727,11 @@ async def proxy(request: Request):
chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
prompt_tok = chunk.prompt_eval_count or 0
comp_tok = chunk.eval_count or 0
record_token_usage(endpoint, model, prompt_tok, comp_tok)
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
if hasattr(chunk, "model_dump_json"):
json_line = chunk.model_dump_json()
else:
json_line = json.dumps(chunk)
json_line = orjson.dumps(chunk).decode("utf-8")
yield json_line.encode("utf-8") + b"\n"
else:
if is_openai_endpoint:
@ -735,11 +741,11 @@ async def proxy(request: Request):
response = async_gen.model_dump_json()
prompt_tok = async_gen.prompt_eval_count or 0
comp_tok = async_gen.eval_count or 0
record_token_usage(endpoint, model, prompt_tok, comp_tok)
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
json_line = (
response
if hasattr(async_gen, "model_dump_json")
else json.dumps(async_gen)
else orjson.dumps(async_gen).decode("utf-8")
)
yield json_line.encode("utf-8") + b"\n"
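
Because model_dump_json() returns str while orjson.dumps() returns bytes, every streaming branch repeats this hasattr() fallback plus an encode. A hypothetical helper (not part of this commit) that stays in bytes end to end would remove the branching:

    import orjson

    def to_json_line(obj) -> bytes:
        """Serialize a pydantic model or plain dict to one JSON line."""
        if hasattr(obj, "model_dump"):
            # pydantic v2: dump to primitives, let orjson do the encoding
            return orjson.dumps(obj.model_dump()) + b"\n"
        return orjson.dumps(obj) + b"\n"

    # usage inside a streaming generator:
    #     yield to_json_line(chunk)
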
@ -764,7 +770,7 @@ async def chat_proxy(request: Request):
# 1. Parse and validate request
try:
body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8"))
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
messages = payload.get("messages")
@ -787,7 +793,7 @@ async def chat_proxy(request: Request):
raise HTTPException(
status_code=400, detail="`options` must be a JSON object"
)
except json.JSONDecodeError as e:
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
@ -837,11 +843,11 @@ async def chat_proxy(request: Request):
# `chunk` can be a dict or a pydantic model; dump it to JSON safely
prompt_tok = chunk.prompt_eval_count or 0
comp_tok = chunk.eval_count or 0
record_token_usage(endpoint, model, prompt_tok, comp_tok)
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
if hasattr(chunk, "model_dump_json"):
json_line = chunk.model_dump_json()
else:
json_line = json.dumps(chunk)
json_line = orjson.dumps(chunk).decode("utf-8")
yield json_line.encode("utf-8") + b"\n"
else:
if is_openai_endpoint:
@ -851,11 +857,11 @@ async def chat_proxy(request: Request):
response = async_gen.model_dump_json()
prompt_tok = async_gen.prompt_eval_count or 0
comp_tok = async_gen.eval_count or 0
record_token_usage(endpoint, model, prompt_tok, comp_tok)
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
json_line = (
response
if hasattr(async_gen, "model_dump_json")
else json.dumps(async_gen)
else orjson.dumps(async_gen).decode("utf-8")
)
yield json_line.encode("utf-8") + b"\n"
@ -882,7 +888,7 @@ async def embedding_proxy(request: Request):
# 1. Parse and validate request
try:
body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8"))
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
prompt = payload.get("prompt")
@ -897,7 +903,7 @@ async def embedding_proxy(request: Request):
raise HTTPException(
status_code=400, detail="Missing required field 'prompt'"
)
except json.JSONDecodeError as e:
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
@ -923,7 +929,7 @@ async def embedding_proxy(request: Request):
if hasattr(async_gen, "model_dump_json"):
json_line = async_gen.model_dump_json()
else:
json_line = json.dumps(async_gen)
json_line = orjson.dumps(async_gen).decode("utf-8")
yield json_line.encode("utf-8") + b"\n"
finally:
# Ensure counter is decremented even if an exception occurs
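
The finally branch guards an in-flight counter so it is decremented even when the stream aborts. The same guarantee can be packaged once as an async context manager instead of repeating try/finally in every generator; a sketch with an invented counter name:

    from contextlib import asynccontextmanager

    _inflight = 0  # hypothetical module-level counter

    @asynccontextmanager
    async def track_inflight():
        global _inflight
        _inflight += 1
        try:
            yield
        finally:
            # runs even if the wrapped generator raises or is cancelled
            _inflight -= 1

    # usage: async with track_inflight(): ... stream the response ...
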
@ -947,7 +953,7 @@ async def embed_proxy(request: Request):
# 1. Parse and validate request
try:
body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8"))
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
_input = payload.get("input")
@ -963,7 +969,7 @@ async def embed_proxy(request: Request):
raise HTTPException(
status_code=400, detail="Missing required field 'input'"
)
except json.JSONDecodeError as e:
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
@ -989,7 +995,7 @@ async def embed_proxy(request: Request):
if hasattr(async_gen, "model_dump_json"):
json_line = async_gen.model_dump_json()
else:
json_line = json.dumps(async_gen)
json_line = orjson.dumps(async_gen).decode("utf-8")
yield json_line.encode("utf-8") + b"\n"
finally:
# Ensure counter is decremented even if an exception occurs
@ -1011,7 +1017,7 @@ async def create_proxy(request: Request):
"""
try:
body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8"))
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
quantize = payload.get("quantize")
@ -1032,7 +1038,7 @@ async def create_proxy(request: Request):
raise HTTPException(
status_code=400, detail="You need to provide either from_ or files parameter!"
)
except json.JSONDecodeError as e:
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
status_lists = []
@ -1062,14 +1068,14 @@ async def show_proxy(request: Request, model: Optional[str] = None):
body_bytes = await request.body()
if not model:
payload = json.loads(body_bytes.decode("utf-8"))
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
except json.JSONDecodeError as e:
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
@ -1097,7 +1103,7 @@ async def copy_proxy(request: Request, source: Optional[str] = None, destination
body_bytes = await request.body()
if not source and not destination:
payload = json.loads(body_bytes.decode("utf-8"))
payload = orjson.loads(body_bytes.decode("utf-8"))
src = payload.get("source")
dst = payload.get("destination")
else:
@ -1112,7 +1118,7 @@ async def copy_proxy(request: Request, source: Optional[str] = None, destination
raise HTTPException(
status_code=400, detail="Missing required field 'destination'"
)
except json.JSONDecodeError as e:
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 3. Iterate over all endpoints to copy the model on each endpoint
@ -1141,14 +1147,14 @@ async def delete_proxy(request: Request, model: Optional[str] = None):
body_bytes = await request.body()
if not model:
payload = json.loads(body_bytes.decode("utf-8"))
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
except json.JSONDecodeError as e:
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Iterate over all endpoints to delete the model on each endpoint
@ -1176,7 +1182,7 @@ async def pull_proxy(request: Request, model: Optional[str] = None):
body_bytes = await request.body()
if not model:
payload = json.loads(body_bytes.decode("utf-8"))
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
insecure = payload.get("insecure")
else:
@ -1186,7 +1192,7 @@ async def pull_proxy(request: Request, model: Optional[str] = None):
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
except json.JSONDecodeError as e:
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Iterate over all endpoints to pull the model
@ -1218,7 +1224,7 @@ async def push_proxy(request: Request):
# 1. Parse and validate request
try:
body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8"))
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
insecure = payload.get("insecure")
@ -1227,7 +1233,7 @@ async def push_proxy(request: Request):
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
except json.JSONDecodeError as e:
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Iterate over all endpoints
@ -1385,7 +1391,7 @@ async def openai_embedding_proxy(request: Request):
# 1. Parse and validate request
try:
body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8"))
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
doc = payload.get("input")
@ -1399,7 +1405,7 @@ async def openai_embedding_proxy(request: Request):
raise HTTPException(
status_code=400, detail="Missing required field 'input'"
)
except json.JSONDecodeError as e:
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
@ -1432,7 +1438,7 @@ async def openai_chat_completions_proxy(request: Request):
# 1. Parse and validate request
try:
body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8"))
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
messages = payload.get("messages")
@ -1483,7 +1489,7 @@ async def openai_chat_completions_proxy(request: Request):
raise HTTPException(
status_code=400, detail="Missing required field 'messages' (must be a list)"
)
except json.JSONDecodeError as e:
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
@ -1501,7 +1507,7 @@ async def openai_chat_completions_proxy(request: Request):
data = (
chunk.model_dump_json()
if hasattr(chunk, "model_dump_json")
else json.dumps(chunk)
else orjson.dumps(chunk).decode("utf-8")
)
if chunk.choices[0].delta.content is not None:
yield f"data: {data}\n\n".encode("utf-8")
@ -1509,11 +1515,11 @@ async def openai_chat_completions_proxy(request: Request):
else:
prompt_tok = async_gen.usage.prompt_tokens or 0
comp_tok = async_gen.usage.completion_tokens or 0
record_token_usage(endpoint, payload.get("model"), prompt_tok, comp_tok)
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
json_line = (
async_gen.model_dump_json()
if hasattr(async_gen, "model_dump_json")
else json.dumps(async_gen)
else orjson.dumps(async_gen).decode("utf-8")
)
yield json_line.encode("utf-8") + b"\n"
@ -1539,7 +1545,7 @@ async def openai_completions_proxy(request: Request):
# 1. Parse and validate request
try:
body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8"))
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
prompt = payload.get("prompt")
@ -1588,7 +1594,7 @@ async def openai_completions_proxy(request: Request):
raise HTTPException(
status_code=400, detail="Missing required field 'prompt'"
)
except json.JSONDecodeError as e:
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
@ -1607,7 +1613,7 @@ async def openai_completions_proxy(request: Request):
data = (
chunk.model_dump_json()
if hasattr(chunk, "model_dump_json")
else json.dumps(chunk)
else orjson.dumps(chunk).decode("utf-8")
)
yield f"data: {data}\n\n".encode("utf-8")
# Final DONE event
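
OpenAI-style streaming uses Server-Sent Events: each chunk is framed as a data: <json> line followed by a blank line, and clients expect a literal [DONE] sentinel at the end. A sketch of that framing (the actual final event is outside this hunk):

    def sse_event(data: str) -> bytes:
        # one SSE event: "data: " prefix, payload, blank-line terminator
        return f"data: {data}\n\n".encode("utf-8")

    DONE_EVENT = b"data: [DONE]\n\n"  # conventional end-of-stream sentinel
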
@ -1615,11 +1621,11 @@ async def openai_completions_proxy(request: Request):
else:
prompt_tok = async_gen.usage.prompt_tokens or 0
comp_tok = async_gen.usage.completion_tokens or 0
record_token_usage(endpoint, payload.get("model"), prompt_tok, comp_tok)
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
json_line = (
async_gen.model_dump_json()
if hasattr(async_gen, "model_dump_json")
else json.dumps(async_gen)
else orjson.dumps(async_gen).decode("utf-8")
)
yield json_line.encode("utf-8") + b"\n"
@ -1774,7 +1780,7 @@ async def startup_event() -> None:
app_state["connector"] = connector
app_state["session"] = session
asyncio.create_task(token_worker())
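
One caveat with a bare asyncio.create_task() here: the event loop holds only weak references to tasks, so an otherwise-unreferenced worker can be garbage-collected mid-run. Keeping the handle (and cancelling it on shutdown) avoids that; a sketch, not the file's current code:

    @app.on_event("startup")
    async def start_token_worker() -> None:
        # strong reference so the worker task cannot be GC'd
        app_state["token_worker"] = asyncio.create_task(token_worker())

    @app.on_event("shutdown")
    async def stop_token_worker() -> None:
        task = app_state.pop("token_worker", None)
        if task:
            task.cancel()
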
@app.on_event("shutdown")
async def shutdown_event() -> None: