Various performance improvements; replace the json module with orjson

This commit is contained in:
Alpha Nerd 2025-11-10 15:37:46 +01:00
parent c6c1059ede
commit 1427e98e6d
2 changed files with 70 insertions and 62 deletions

View file

@ -1,6 +1,7 @@
aiohappyeyeballs==2.6.1 aiohappyeyeballs==2.6.1
aiohttp==3.12.15 aiohttp==3.12.15
aiosignal==1.4.0 aiosignal==1.4.0
annotated-doc==0.0.3
annotated-types==0.7.0 annotated-types==0.7.0
anyio==4.10.0 anyio==4.10.0
async-timeout==5.0.1 async-timeout==5.0.1
@ -20,6 +21,7 @@ jiter==0.10.0
multidict==6.6.4 multidict==6.6.4
ollama==0.6.0 ollama==0.6.0
openai==1.102.0 openai==1.102.0
orjson==3.11.4
pillow==11.3.0 pillow==11.3.0
propcache==0.3.2 propcache==0.3.2
pydantic==2.11.7 pydantic==2.11.7

130
router.py
View file

@ -6,7 +6,7 @@ version: 0.4
license: AGPL license: AGPL
""" """
# ------------------------------------------------------------- # -------------------------------------------------------------
import json, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random, base64, io import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random, base64, io
from pathlib import Path from pathlib import Path
from typing import Dict, Set, List, Optional from typing import Dict, Set, List, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
@ -30,10 +30,11 @@ _models_cache: dict[str, tuple[Set[str], float]] = {}
_error_cache: dict[str, float] = {} _error_cache: dict[str, float] = {}
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# SSE Queues # Queues
# ------------------------------------------------------------------ # ------------------------------------------------------------------
_subscribers: Set[asyncio.Queue] = set() _subscribers: Set[asyncio.Queue] = set()
_subscribers_lock = asyncio.Lock() _subscribers_lock = asyncio.Lock()
token_queue: asyncio.Queue[tuple[str, str, int, int]] = asyncio.Queue()
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# aiohttp Global Sessions # aiohttp Global Sessions
@ -125,6 +126,7 @@ default_headers={
usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int)) usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int)) token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
usage_lock = asyncio.Lock() # protects access to usage_counts usage_lock = asyncio.Lock() # protects access to usage_counts
token_usage_lock = asyncio.Lock()
# ------------------------------------------------------------- # -------------------------------------------------------------
# 4. Helperfunctions # 4. Helperfunctions
@ -192,12 +194,12 @@ def is_ext_openai_endpoint(endpoint: str) -> bool:
return True # It's an external OpenAI endpoint return True # It's an external OpenAI endpoint
def record_token_usage(endpoint: str, model: str, prompt: int = 0, completion: int = 0) -> None: async def token_worker() -> None:
async def _record(): while True:
async with usage_lock: # reuse the same lock that protects usage_counts endpoint, model, prompt, comp = await token_queue.get()
token_usage_counts[endpoint][model] += (prompt + completion) async with token_usage_lock:
await publish_snapshot() # immediately broadcast the new totals token_usage_counts[endpoint][model] += (prompt + comp)
asyncio.create_task(_record()) await publish_snapshot()
class fetch: class fetch:
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]: async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
@ -267,6 +269,8 @@ class fetch:
set is returned. set is returned.
""" """
client: aiohttp.ClientSession = app_state["session"] client: aiohttp.ClientSession = app_state["session"]
if is_ext_openai_endpoint(endpoint):
return set()
try: try:
async with client.get(f"{endpoint}/api/ps") as resp: async with client.get(f"{endpoint}/api/ps") as resp:
await _ensure_success(resp) await _ensure_success(resp)
@ -428,18 +432,19 @@ def transform_images_to_data_urls(message_list):
class rechunk: class rechunk:
def openai_chat_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.ChatResponse: def openai_chat_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.ChatResponse:
now = time.perf_counter()
if chunk.choices == [] and chunk.usage is not None: if chunk.choices == [] and chunk.usage is not None:
return ollama.ChatResponse( return ollama.ChatResponse(
model=chunk.model, model=chunk.model,
created_at=iso8601_ns(), created_at=iso8601_ns(),
done=True, done=True,
done_reason='stop', done_reason='stop',
total_duration=int((time.perf_counter() - start_ts) * 1_000_000_000), total_duration=int((now - start_ts) * 1_000_000_000),
load_duration=100000, load_duration=100000,
prompt_eval_count=int(chunk.usage.prompt_tokens), prompt_eval_count=int(chunk.usage.prompt_tokens),
prompt_eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)), prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)),
eval_count=int(chunk.usage.completion_tokens), eval_count=int(chunk.usage.completion_tokens),
eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000), eval_duration=int((now - start_ts) * 1_000_000_000),
message={"role": "assistant"} message={"role": "assistant"}
) )
with_thinking = chunk.choices[0] if chunk.choices[0] else None with_thinking = chunk.choices[0] if chunk.choices[0] else None
@ -463,16 +468,17 @@ class rechunk:
created_at=iso8601_ns(), created_at=iso8601_ns(),
done=True if chunk.usage is not None else False, done=True if chunk.usage is not None else False,
done_reason=chunk.choices[0].finish_reason, #if chunk.choices[0].finish_reason is not None else None, done_reason=chunk.choices[0].finish_reason, #if chunk.choices[0].finish_reason is not None else None,
total_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0, total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
load_duration=100000, load_duration=100000,
prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0, prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
prompt_eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0, prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0, eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0, eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
message=assistant_msg) message=assistant_msg)
return rechunk return rechunk
def openai_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.GenerateResponse: def openai_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.GenerateResponse:
now = time.perf_counter()
with_thinking = chunk.choices[0] if chunk.choices[0] else None with_thinking = chunk.choices[0] if chunk.choices[0] else None
thinking = getattr(with_thinking, "reasoning", None) if with_thinking else None thinking = getattr(with_thinking, "reasoning", None) if with_thinking else None
rechunk = ollama.GenerateResponse( rechunk = ollama.GenerateResponse(
@ -480,12 +486,12 @@ class rechunk:
created_at=iso8601_ns(), created_at=iso8601_ns(),
done=True if chunk.usage is not None else False, done=True if chunk.usage is not None else False,
done_reason=chunk.choices[0].finish_reason, done_reason=chunk.choices[0].finish_reason,
total_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0, total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
load_duration=10000, load_duration=10000,
prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0, prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
prompt_eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0, prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0, eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0, eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
response=chunk.choices[0].text or '', response=chunk.choices[0].text or '',
thinking=thinking) thinking=thinking)
return rechunk return rechunk
@ -514,9 +520,9 @@ class rechunk:
# ------------------------------------------------------------------ # ------------------------------------------------------------------
async def publish_snapshot(): async def publish_snapshot():
async with usage_lock: async with usage_lock:
snapshot = json.dumps({"usage_counts": usage_counts, snapshot = orjson.dumps({"usage_counts": usage_counts,
"token_usage_counts": token_usage_counts, "token_usage_counts": token_usage_counts,
}, sort_keys=True) }, option=orjson.OPT_SORT_KEYS).decode("utf-8")
async with _subscribers_lock: async with _subscribers_lock:
for q in _subscribers: for q in _subscribers:
# If the queue is full, drop the message to avoid backpressure. # If the queue is full, drop the message to avoid backpressure.
@ -650,7 +656,7 @@ async def proxy(request: Request):
""" """
try: try:
body_bytes = await request.body() body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8")) payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model") model = payload.get("model")
prompt = payload.get("prompt") prompt = payload.get("prompt")
@ -674,7 +680,7 @@ async def proxy(request: Request):
raise HTTPException( raise HTTPException(
status_code=400, detail="Missing required field 'prompt'" status_code=400, detail="Missing required field 'prompt'"
) )
except json.JSONDecodeError as e: except orjson.JSONDecodeError as e:
error_msg = f"Invalid JSON format in request body: {str(e)}. Please ensure the request is properly formatted." error_msg = f"Invalid JSON format in request body: {str(e)}. Please ensure the request is properly formatted."
raise HTTPException(status_code=400, detail=error_msg) from e raise HTTPException(status_code=400, detail=error_msg) from e
@ -721,11 +727,11 @@ async def proxy(request: Request):
chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts) chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
prompt_tok = chunk.prompt_eval_count or 0 prompt_tok = chunk.prompt_eval_count or 0
comp_tok = chunk.eval_count or 0 comp_tok = chunk.eval_count or 0
record_token_usage(endpoint, model, prompt_tok, comp_tok) await token_queue.put((endpoint, model, prompt_tok, comp_tok))
if hasattr(chunk, "model_dump_json"): if hasattr(chunk, "model_dump_json"):
json_line = chunk.model_dump_json() json_line = chunk.model_dump_json()
else: else:
json_line = json.dumps(chunk) json_line = orjson.dumps(chunk)
yield json_line.encode("utf-8") + b"\n" yield json_line.encode("utf-8") + b"\n"
else: else:
if is_openai_endpoint: if is_openai_endpoint:
@ -735,11 +741,11 @@ async def proxy(request: Request):
response = async_gen.model_dump_json() response = async_gen.model_dump_json()
prompt_tok = async_gen.prompt_eval_count or 0 prompt_tok = async_gen.prompt_eval_count or 0
comp_tok = async_gen.eval_count or 0 comp_tok = async_gen.eval_count or 0
record_token_usage(endpoint, model, prompt_tok, comp_tok) await token_queue.put((endpoint, model, prompt_tok, comp_tok))
json_line = ( json_line = (
response response
if hasattr(async_gen, "model_dump_json") if hasattr(async_gen, "model_dump_json")
else json.dumps(async_gen) else orjson.dumps(async_gen)
) )
yield json_line.encode("utf-8") + b"\n" yield json_line.encode("utf-8") + b"\n"
@ -764,7 +770,7 @@ async def chat_proxy(request: Request):
# 1. Parse and validate request # 1. Parse and validate request
try: try:
body_bytes = await request.body() body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8")) payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model") model = payload.get("model")
messages = payload.get("messages") messages = payload.get("messages")
@ -787,7 +793,7 @@ async def chat_proxy(request: Request):
raise HTTPException( raise HTTPException(
status_code=400, detail="`options` must be a JSON object" status_code=400, detail="`options` must be a JSON object"
) )
except json.JSONDecodeError as e: except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic # 2. Endpoint logic
@ -837,11 +843,11 @@ async def chat_proxy(request: Request):
# `chunk` can be a dict or a pydantic model dump to JSON safely # `chunk` can be a dict or a pydantic model dump to JSON safely
prompt_tok = chunk.prompt_eval_count or 0 prompt_tok = chunk.prompt_eval_count or 0
comp_tok = chunk.eval_count or 0 comp_tok = chunk.eval_count or 0
record_token_usage(endpoint, model, prompt_tok, comp_tok) await token_queue.put((endpoint, model, prompt_tok, comp_tok))
if hasattr(chunk, "model_dump_json"): if hasattr(chunk, "model_dump_json"):
json_line = chunk.model_dump_json() json_line = chunk.model_dump_json()
else: else:
json_line = json.dumps(chunk) json_line = orjson.dumps(chunk)
yield json_line.encode("utf-8") + b"\n" yield json_line.encode("utf-8") + b"\n"
else: else:
if is_openai_endpoint: if is_openai_endpoint:
@ -851,11 +857,11 @@ async def chat_proxy(request: Request):
response = async_gen.model_dump_json() response = async_gen.model_dump_json()
prompt_tok = async_gen.prompt_eval_count or 0 prompt_tok = async_gen.prompt_eval_count or 0
comp_tok = async_gen.eval_count or 0 comp_tok = async_gen.eval_count or 0
record_token_usage(endpoint, model, prompt_tok, comp_tok) await token_queue.put((endpoint, model, prompt_tok, comp_tok))
json_line = ( json_line = (
response response
if hasattr(async_gen, "model_dump_json") if hasattr(async_gen, "model_dump_json")
else json.dumps(async_gen) else orjson.dumps(async_gen)
) )
yield json_line.encode("utf-8") + b"\n" yield json_line.encode("utf-8") + b"\n"
@ -882,7 +888,7 @@ async def embedding_proxy(request: Request):
# 1. Parse and validate request # 1. Parse and validate request
try: try:
body_bytes = await request.body() body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8")) payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model") model = payload.get("model")
prompt = payload.get("prompt") prompt = payload.get("prompt")
@ -897,7 +903,7 @@ async def embedding_proxy(request: Request):
raise HTTPException( raise HTTPException(
status_code=400, detail="Missing required field 'prompt'" status_code=400, detail="Missing required field 'prompt'"
) )
except json.JSONDecodeError as e: except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic # 2. Endpoint logic
@ -923,7 +929,7 @@ async def embedding_proxy(request: Request):
if hasattr(async_gen, "model_dump_json"): if hasattr(async_gen, "model_dump_json"):
json_line = async_gen.model_dump_json() json_line = async_gen.model_dump_json()
else: else:
json_line = json.dumps(async_gen) json_line = orjson.dumps(async_gen)
yield json_line.encode("utf-8") + b"\n" yield json_line.encode("utf-8") + b"\n"
finally: finally:
# Ensure counter is decremented even if an exception occurs # Ensure counter is decremented even if an exception occurs
@ -947,7 +953,7 @@ async def embed_proxy(request: Request):
# 1. Parse and validate request # 1. Parse and validate request
try: try:
body_bytes = await request.body() body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8")) payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model") model = payload.get("model")
_input = payload.get("input") _input = payload.get("input")
@ -963,7 +969,7 @@ async def embed_proxy(request: Request):
raise HTTPException( raise HTTPException(
status_code=400, detail="Missing required field 'input'" status_code=400, detail="Missing required field 'input'"
) )
except json.JSONDecodeError as e: except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic # 2. Endpoint logic
@ -989,7 +995,7 @@ async def embed_proxy(request: Request):
if hasattr(async_gen, "model_dump_json"): if hasattr(async_gen, "model_dump_json"):
json_line = async_gen.model_dump_json() json_line = async_gen.model_dump_json()
else: else:
json_line = json.dumps(async_gen) json_line = orjson.dumps(async_gen)
yield json_line.encode("utf-8") + b"\n" yield json_line.encode("utf-8") + b"\n"
finally: finally:
# Ensure counter is decremented even if an exception occurs # Ensure counter is decremented even if an exception occurs
@ -1011,7 +1017,7 @@ async def create_proxy(request: Request):
""" """
try: try:
body_bytes = await request.body() body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8")) payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model") model = payload.get("model")
quantize = payload.get("quantize") quantize = payload.get("quantize")
@ -1032,7 +1038,7 @@ async def create_proxy(request: Request):
raise HTTPException( raise HTTPException(
status_code=400, detail="You need to provide either from_ or files parameter!" status_code=400, detail="You need to provide either from_ or files parameter!"
) )
except json.JSONDecodeError as e: except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
status_lists = [] status_lists = []
@ -1062,14 +1068,14 @@ async def show_proxy(request: Request, model: Optional[str] = None):
body_bytes = await request.body() body_bytes = await request.body()
if not model: if not model:
payload = json.loads(body_bytes.decode("utf-8")) payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model") model = payload.get("model")
if not model: if not model:
raise HTTPException( raise HTTPException(
status_code=400, detail="Missing required field 'model'" status_code=400, detail="Missing required field 'model'"
) )
except json.JSONDecodeError as e: except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic # 2. Endpoint logic
@ -1097,7 +1103,7 @@ async def copy_proxy(request: Request, source: Optional[str] = None, destination
body_bytes = await request.body() body_bytes = await request.body()
if not source and not destination: if not source and not destination:
payload = json.loads(body_bytes.decode("utf-8")) payload = orjson.loads(body_bytes.decode("utf-8"))
src = payload.get("source") src = payload.get("source")
dst = payload.get("destination") dst = payload.get("destination")
else: else:
@ -1112,7 +1118,7 @@ async def copy_proxy(request: Request, source: Optional[str] = None, destination
raise HTTPException( raise HTTPException(
status_code=400, detail="Missing required field 'destination'" status_code=400, detail="Missing required field 'destination'"
) )
except json.JSONDecodeError as e: except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 3. Iterate over all endpoints to copy the model on each endpoint # 3. Iterate over all endpoints to copy the model on each endpoint
@ -1141,14 +1147,14 @@ async def delete_proxy(request: Request, model: Optional[str] = None):
body_bytes = await request.body() body_bytes = await request.body()
if not model: if not model:
payload = json.loads(body_bytes.decode("utf-8")) payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model") model = payload.get("model")
if not model: if not model:
raise HTTPException( raise HTTPException(
status_code=400, detail="Missing required field 'model'" status_code=400, detail="Missing required field 'model'"
) )
except json.JSONDecodeError as e: except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Iterate over all endpoints to delete the model on each endpoint # 2. Iterate over all endpoints to delete the model on each endpoint
@ -1176,7 +1182,7 @@ async def pull_proxy(request: Request, model: Optional[str] = None):
body_bytes = await request.body() body_bytes = await request.body()
if not model: if not model:
payload = json.loads(body_bytes.decode("utf-8")) payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model") model = payload.get("model")
insecure = payload.get("insecure") insecure = payload.get("insecure")
else: else:
@ -1186,7 +1192,7 @@ async def pull_proxy(request: Request, model: Optional[str] = None):
raise HTTPException( raise HTTPException(
status_code=400, detail="Missing required field 'model'" status_code=400, detail="Missing required field 'model'"
) )
except json.JSONDecodeError as e: except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Iterate over all endpoints to pull the model # 2. Iterate over all endpoints to pull the model
@ -1218,7 +1224,7 @@ async def push_proxy(request: Request):
# 1. Parse and validate request # 1. Parse and validate request
try: try:
body_bytes = await request.body() body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8")) payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model") model = payload.get("model")
insecure = payload.get("insecure") insecure = payload.get("insecure")
@ -1227,7 +1233,7 @@ async def push_proxy(request: Request):
raise HTTPException( raise HTTPException(
status_code=400, detail="Missing required field 'model'" status_code=400, detail="Missing required field 'model'"
) )
except json.JSONDecodeError as e: except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Iterate over all endpoints # 2. Iterate over all endpoints
@ -1385,7 +1391,7 @@ async def openai_embedding_proxy(request: Request):
# 1. Parse and validate request # 1. Parse and validate request
try: try:
body_bytes = await request.body() body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8")) payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model") model = payload.get("model")
doc = payload.get("input") doc = payload.get("input")
@ -1399,7 +1405,7 @@ async def openai_embedding_proxy(request: Request):
raise HTTPException( raise HTTPException(
status_code=400, detail="Missing required field 'input'" status_code=400, detail="Missing required field 'input'"
) )
except json.JSONDecodeError as e: except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic # 2. Endpoint logic
@ -1432,7 +1438,7 @@ async def openai_chat_completions_proxy(request: Request):
# 1. Parse and validate request # 1. Parse and validate request
try: try:
body_bytes = await request.body() body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8")) payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model") model = payload.get("model")
messages = payload.get("messages") messages = payload.get("messages")
@ -1483,7 +1489,7 @@ async def openai_chat_completions_proxy(request: Request):
raise HTTPException( raise HTTPException(
status_code=400, detail="Missing required field 'messages' (must be a list)" status_code=400, detail="Missing required field 'messages' (must be a list)"
) )
except json.JSONDecodeError as e: except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic # 2. Endpoint logic
@ -1501,7 +1507,7 @@ async def openai_chat_completions_proxy(request: Request):
data = ( data = (
chunk.model_dump_json() chunk.model_dump_json()
if hasattr(chunk, "model_dump_json") if hasattr(chunk, "model_dump_json")
else json.dumps(chunk) else orjson.dumps(chunk)
) )
if chunk.choices[0].delta.content is not None: if chunk.choices[0].delta.content is not None:
yield f"data: {data}\n\n".encode("utf-8") yield f"data: {data}\n\n".encode("utf-8")
@ -1509,11 +1515,11 @@ async def openai_chat_completions_proxy(request: Request):
else: else:
prompt_tok = async_gen.usage.prompt_tokens or 0 prompt_tok = async_gen.usage.prompt_tokens or 0
comp_tok = async_gen.usage.completion_tokens or 0 comp_tok = async_gen.usage.completion_tokens or 0
record_token_usage(endpoint, payload.get("model"), prompt_tok, comp_tok) await token_queue.put((endpoint, model, prompt_tok, comp_tok))
json_line = ( json_line = (
async_gen.model_dump_json() async_gen.model_dump_json()
if hasattr(async_gen, "model_dump_json") if hasattr(async_gen, "model_dump_json")
else json.dumps(async_gen) else orjson.dumps(async_gen)
) )
yield json_line.encode("utf-8") + b"\n" yield json_line.encode("utf-8") + b"\n"
@ -1539,7 +1545,7 @@ async def openai_completions_proxy(request: Request):
# 1. Parse and validate request # 1. Parse and validate request
try: try:
body_bytes = await request.body() body_bytes = await request.body()
payload = json.loads(body_bytes.decode("utf-8")) payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model") model = payload.get("model")
prompt = payload.get("prompt") prompt = payload.get("prompt")
@ -1588,7 +1594,7 @@ async def openai_completions_proxy(request: Request):
raise HTTPException( raise HTTPException(
status_code=400, detail="Missing required field 'prompt'" status_code=400, detail="Missing required field 'prompt'"
) )
except json.JSONDecodeError as e: except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic # 2. Endpoint logic
@ -1607,7 +1613,7 @@ async def openai_completions_proxy(request: Request):
data = ( data = (
chunk.model_dump_json() chunk.model_dump_json()
if hasattr(chunk, "model_dump_json") if hasattr(chunk, "model_dump_json")
else json.dumps(chunk) else orjson.dumps(chunk)
) )
yield f"data: {data}\n\n".encode("utf-8") yield f"data: {data}\n\n".encode("utf-8")
# Final DONE event # Final DONE event
@ -1615,11 +1621,11 @@ async def openai_completions_proxy(request: Request):
else: else:
prompt_tok = async_gen.usage.prompt_tokens or 0 prompt_tok = async_gen.usage.prompt_tokens or 0
comp_tok = async_gen.usage.completion_tokens or 0 comp_tok = async_gen.usage.completion_tokens or 0
record_token_usage(endpoint, payload.get("model"), prompt_tok, comp_tok) await token_queue.put((endpoint, model, prompt_tok, comp_tok))
json_line = ( json_line = (
async_gen.model_dump_json() async_gen.model_dump_json()
if hasattr(async_gen, "model_dump_json") if hasattr(async_gen, "model_dump_json")
else json.dumps(async_gen) else orjson.dumps(async_gen)
) )
yield json_line.encode("utf-8") + b"\n" yield json_line.encode("utf-8") + b"\n"
@ -1774,7 +1780,7 @@ async def startup_event() -> None:
app_state["connector"] = connector app_state["connector"] = connector
app_state["session"] = session app_state["session"] = session
asyncio.create_task(token_worker())
@app.on_event("shutdown") @app.on_event("shutdown")
async def shutdown_event() -> None: async def shutdown_event() -> None: