Various performance improvements; replace json with orjson
This commit is contained in:
parent
c6c1059ede
commit
1427e98e6d
2 changed files with 70 additions and 62 deletions
|
|
@ -1,6 +1,7 @@
|
||||||
aiohappyeyeballs==2.6.1
|
aiohappyeyeballs==2.6.1
|
||||||
aiohttp==3.12.15
|
aiohttp==3.12.15
|
||||||
aiosignal==1.4.0
|
aiosignal==1.4.0
|
||||||
|
annotated-doc==0.0.3
|
||||||
annotated-types==0.7.0
|
annotated-types==0.7.0
|
||||||
anyio==4.10.0
|
anyio==4.10.0
|
||||||
async-timeout==5.0.1
|
async-timeout==5.0.1
|
||||||
|
|
@ -20,6 +21,7 @@ jiter==0.10.0
|
||||||
multidict==6.6.4
|
multidict==6.6.4
|
||||||
ollama==0.6.0
|
ollama==0.6.0
|
||||||
openai==1.102.0
|
openai==1.102.0
|
||||||
|
orjson==3.11.4
|
||||||
pillow==11.3.0
|
pillow==11.3.0
|
||||||
propcache==0.3.2
|
propcache==0.3.2
|
||||||
pydantic==2.11.7
|
pydantic==2.11.7
|
||||||
|
|
|
||||||
130
router.py
130
router.py
|
|
@ -6,7 +6,7 @@ version: 0.4
|
||||||
license: AGPL
|
license: AGPL
|
||||||
"""
|
"""
|
||||||
# -------------------------------------------------------------
|
# -------------------------------------------------------------
|
||||||
import json, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random, base64, io
|
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random, base64, io
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Set, List, Optional
|
from typing import Dict, Set, List, Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
@ -30,10 +30,11 @@ _models_cache: dict[str, tuple[Set[str], float]] = {}
|
||||||
_error_cache: dict[str, float] = {}
|
_error_cache: dict[str, float] = {}
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# SSE Queues
|
# Queues
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
_subscribers: Set[asyncio.Queue] = set()
|
_subscribers: Set[asyncio.Queue] = set()
|
||||||
_subscribers_lock = asyncio.Lock()
|
_subscribers_lock = asyncio.Lock()
|
||||||
|
token_queue: asyncio.Queue[tuple[str, str, int, int]] = asyncio.Queue()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# aiohttp Global Sessions
|
# aiohttp Global Sessions
|
||||||
|
|
@ -125,6 +126,7 @@ default_headers={
|
||||||
usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||||
token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||||
usage_lock = asyncio.Lock() # protects access to usage_counts
|
usage_lock = asyncio.Lock() # protects access to usage_counts
|
||||||
|
token_usage_lock = asyncio.Lock()
|
||||||
|
|
||||||
# -------------------------------------------------------------
|
# -------------------------------------------------------------
|
||||||
# 4. Helperfunctions
|
# 4. Helperfunctions
|
||||||
|
|
@ -192,12 +194,12 @@ def is_ext_openai_endpoint(endpoint: str) -> bool:
|
||||||
|
|
||||||
return True # It's an external OpenAI endpoint
|
return True # It's an external OpenAI endpoint
|
||||||
|
|
||||||
def record_token_usage(endpoint: str, model: str, prompt: int = 0, completion: int = 0) -> None:
|
async def token_worker() -> None:
|
||||||
async def _record():
|
while True:
|
||||||
async with usage_lock: # reuse the same lock that protects usage_counts
|
endpoint, model, prompt, comp = await token_queue.get()
|
||||||
token_usage_counts[endpoint][model] += (prompt + completion)
|
async with token_usage_lock:
|
||||||
await publish_snapshot() # immediately broadcast the new totals
|
token_usage_counts[endpoint][model] += (prompt + comp)
|
||||||
asyncio.create_task(_record())
|
await publish_snapshot()
|
||||||
|
|
||||||
class fetch:
|
class fetch:
|
||||||
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
|
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
|
||||||
|
|
@ -267,6 +269,8 @@ class fetch:
|
||||||
set is returned.
|
set is returned.
|
||||||
"""
|
"""
|
||||||
client: aiohttp.ClientSession = app_state["session"]
|
client: aiohttp.ClientSession = app_state["session"]
|
||||||
|
if is_ext_openai_endpoint(endpoint):
|
||||||
|
return set()
|
||||||
try:
|
try:
|
||||||
async with client.get(f"{endpoint}/api/ps") as resp:
|
async with client.get(f"{endpoint}/api/ps") as resp:
|
||||||
await _ensure_success(resp)
|
await _ensure_success(resp)
|
||||||
|
|
@ -428,18 +432,19 @@ def transform_images_to_data_urls(message_list):
|
||||||
|
|
||||||
class rechunk:
|
class rechunk:
|
||||||
def openai_chat_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.ChatResponse:
|
def openai_chat_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.ChatResponse:
|
||||||
|
now = time.perf_counter()
|
||||||
if chunk.choices == [] and chunk.usage is not None:
|
if chunk.choices == [] and chunk.usage is not None:
|
||||||
return ollama.ChatResponse(
|
return ollama.ChatResponse(
|
||||||
model=chunk.model,
|
model=chunk.model,
|
||||||
created_at=iso8601_ns(),
|
created_at=iso8601_ns(),
|
||||||
done=True,
|
done=True,
|
||||||
done_reason='stop',
|
done_reason='stop',
|
||||||
total_duration=int((time.perf_counter() - start_ts) * 1_000_000_000),
|
total_duration=int((now - start_ts) * 1_000_000_000),
|
||||||
load_duration=100000,
|
load_duration=100000,
|
||||||
prompt_eval_count=int(chunk.usage.prompt_tokens),
|
prompt_eval_count=int(chunk.usage.prompt_tokens),
|
||||||
prompt_eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)),
|
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)),
|
||||||
eval_count=int(chunk.usage.completion_tokens),
|
eval_count=int(chunk.usage.completion_tokens),
|
||||||
eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000),
|
eval_duration=int((now - start_ts) * 1_000_000_000),
|
||||||
message={"role": "assistant"}
|
message={"role": "assistant"}
|
||||||
)
|
)
|
||||||
with_thinking = chunk.choices[0] if chunk.choices[0] else None
|
with_thinking = chunk.choices[0] if chunk.choices[0] else None
|
||||||
|
|
@ -463,16 +468,17 @@ class rechunk:
|
||||||
created_at=iso8601_ns(),
|
created_at=iso8601_ns(),
|
||||||
done=True if chunk.usage is not None else False,
|
done=True if chunk.usage is not None else False,
|
||||||
done_reason=chunk.choices[0].finish_reason, #if chunk.choices[0].finish_reason is not None else None,
|
done_reason=chunk.choices[0].finish_reason, #if chunk.choices[0].finish_reason is not None else None,
|
||||||
total_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
|
total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
|
||||||
load_duration=100000,
|
load_duration=100000,
|
||||||
prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
|
prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
|
||||||
prompt_eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
|
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
|
||||||
eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
|
eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
|
||||||
eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
|
eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
|
||||||
message=assistant_msg)
|
message=assistant_msg)
|
||||||
return rechunk
|
return rechunk
|
||||||
|
|
||||||
def openai_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.GenerateResponse:
|
def openai_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.GenerateResponse:
|
||||||
|
now = time.perf_counter()
|
||||||
with_thinking = chunk.choices[0] if chunk.choices[0] else None
|
with_thinking = chunk.choices[0] if chunk.choices[0] else None
|
||||||
thinking = getattr(with_thinking, "reasoning", None) if with_thinking else None
|
thinking = getattr(with_thinking, "reasoning", None) if with_thinking else None
|
||||||
rechunk = ollama.GenerateResponse(
|
rechunk = ollama.GenerateResponse(
|
||||||
|
|
@ -480,12 +486,12 @@ class rechunk:
|
||||||
created_at=iso8601_ns(),
|
created_at=iso8601_ns(),
|
||||||
done=True if chunk.usage is not None else False,
|
done=True if chunk.usage is not None else False,
|
||||||
done_reason=chunk.choices[0].finish_reason,
|
done_reason=chunk.choices[0].finish_reason,
|
||||||
total_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
|
total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
|
||||||
load_duration=10000,
|
load_duration=10000,
|
||||||
prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
|
prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
|
||||||
prompt_eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
|
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
|
||||||
eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
|
eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
|
||||||
eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
|
eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
|
||||||
response=chunk.choices[0].text or '',
|
response=chunk.choices[0].text or '',
|
||||||
thinking=thinking)
|
thinking=thinking)
|
||||||
return rechunk
|
return rechunk
|
||||||
|
|
@ -514,9 +520,9 @@ class rechunk:
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
async def publish_snapshot():
|
async def publish_snapshot():
|
||||||
async with usage_lock:
|
async with usage_lock:
|
||||||
snapshot = json.dumps({"usage_counts": usage_counts,
|
snapshot = orjson.dumps({"usage_counts": usage_counts,
|
||||||
"token_usage_counts": token_usage_counts,
|
"token_usage_counts": token_usage_counts,
|
||||||
}, sort_keys=True)
|
}, option=orjson.OPT_SORT_KEYS).decode("utf-8")
|
||||||
async with _subscribers_lock:
|
async with _subscribers_lock:
|
||||||
for q in _subscribers:
|
for q in _subscribers:
|
||||||
# If the queue is full, drop the message to avoid back‑pressure.
|
# If the queue is full, drop the message to avoid back‑pressure.
|
||||||
|
|
@ -650,7 +656,7 @@ async def proxy(request: Request):
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
body_bytes = await request.body()
|
body_bytes = await request.body()
|
||||||
payload = json.loads(body_bytes.decode("utf-8"))
|
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||||
|
|
||||||
model = payload.get("model")
|
model = payload.get("model")
|
||||||
prompt = payload.get("prompt")
|
prompt = payload.get("prompt")
|
||||||
|
|
@ -674,7 +680,7 @@ async def proxy(request: Request):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Missing required field 'prompt'"
|
status_code=400, detail="Missing required field 'prompt'"
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError as e:
|
except orjson.JSONDecodeError as e:
|
||||||
error_msg = f"Invalid JSON format in request body: {str(e)}. Please ensure the request is properly formatted."
|
error_msg = f"Invalid JSON format in request body: {str(e)}. Please ensure the request is properly formatted."
|
||||||
raise HTTPException(status_code=400, detail=error_msg) from e
|
raise HTTPException(status_code=400, detail=error_msg) from e
|
||||||
|
|
||||||
|
|
@ -721,11 +727,11 @@ async def proxy(request: Request):
|
||||||
chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
|
chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
|
||||||
prompt_tok = chunk.prompt_eval_count or 0
|
prompt_tok = chunk.prompt_eval_count or 0
|
||||||
comp_tok = chunk.eval_count or 0
|
comp_tok = chunk.eval_count or 0
|
||||||
record_token_usage(endpoint, model, prompt_tok, comp_tok)
|
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
|
||||||
if hasattr(chunk, "model_dump_json"):
|
if hasattr(chunk, "model_dump_json"):
|
||||||
json_line = chunk.model_dump_json()
|
json_line = chunk.model_dump_json()
|
||||||
else:
|
else:
|
||||||
json_line = json.dumps(chunk)
|
json_line = orjson.dumps(chunk)
|
||||||
yield json_line.encode("utf-8") + b"\n"
|
yield json_line.encode("utf-8") + b"\n"
|
||||||
else:
|
else:
|
||||||
if is_openai_endpoint:
|
if is_openai_endpoint:
|
||||||
|
|
@ -735,11 +741,11 @@ async def proxy(request: Request):
|
||||||
response = async_gen.model_dump_json()
|
response = async_gen.model_dump_json()
|
||||||
prompt_tok = async_gen.prompt_eval_count or 0
|
prompt_tok = async_gen.prompt_eval_count or 0
|
||||||
comp_tok = async_gen.eval_count or 0
|
comp_tok = async_gen.eval_count or 0
|
||||||
record_token_usage(endpoint, model, prompt_tok, comp_tok)
|
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
|
||||||
json_line = (
|
json_line = (
|
||||||
response
|
response
|
||||||
if hasattr(async_gen, "model_dump_json")
|
if hasattr(async_gen, "model_dump_json")
|
||||||
else json.dumps(async_gen)
|
else orjson.dumps(async_gen)
|
||||||
)
|
)
|
||||||
yield json_line.encode("utf-8") + b"\n"
|
yield json_line.encode("utf-8") + b"\n"
|
||||||
|
|
||||||
|
|
@ -764,7 +770,7 @@ async def chat_proxy(request: Request):
|
||||||
# 1. Parse and validate request
|
# 1. Parse and validate request
|
||||||
try:
|
try:
|
||||||
body_bytes = await request.body()
|
body_bytes = await request.body()
|
||||||
payload = json.loads(body_bytes.decode("utf-8"))
|
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||||
|
|
||||||
model = payload.get("model")
|
model = payload.get("model")
|
||||||
messages = payload.get("messages")
|
messages = payload.get("messages")
|
||||||
|
|
@ -787,7 +793,7 @@ async def chat_proxy(request: Request):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="`options` must be a JSON object"
|
status_code=400, detail="`options` must be a JSON object"
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError as e:
|
except orjson.JSONDecodeError as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
# 2. Endpoint logic
|
# 2. Endpoint logic
|
||||||
|
|
@ -837,11 +843,11 @@ async def chat_proxy(request: Request):
|
||||||
# `chunk` can be a dict or a pydantic model – dump to JSON safely
|
# `chunk` can be a dict or a pydantic model – dump to JSON safely
|
||||||
prompt_tok = chunk.prompt_eval_count or 0
|
prompt_tok = chunk.prompt_eval_count or 0
|
||||||
comp_tok = chunk.eval_count or 0
|
comp_tok = chunk.eval_count or 0
|
||||||
record_token_usage(endpoint, model, prompt_tok, comp_tok)
|
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
|
||||||
if hasattr(chunk, "model_dump_json"):
|
if hasattr(chunk, "model_dump_json"):
|
||||||
json_line = chunk.model_dump_json()
|
json_line = chunk.model_dump_json()
|
||||||
else:
|
else:
|
||||||
json_line = json.dumps(chunk)
|
json_line = orjson.dumps(chunk)
|
||||||
yield json_line.encode("utf-8") + b"\n"
|
yield json_line.encode("utf-8") + b"\n"
|
||||||
else:
|
else:
|
||||||
if is_openai_endpoint:
|
if is_openai_endpoint:
|
||||||
|
|
@ -851,11 +857,11 @@ async def chat_proxy(request: Request):
|
||||||
response = async_gen.model_dump_json()
|
response = async_gen.model_dump_json()
|
||||||
prompt_tok = async_gen.prompt_eval_count or 0
|
prompt_tok = async_gen.prompt_eval_count or 0
|
||||||
comp_tok = async_gen.eval_count or 0
|
comp_tok = async_gen.eval_count or 0
|
||||||
record_token_usage(endpoint, model, prompt_tok, comp_tok)
|
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
|
||||||
json_line = (
|
json_line = (
|
||||||
response
|
response
|
||||||
if hasattr(async_gen, "model_dump_json")
|
if hasattr(async_gen, "model_dump_json")
|
||||||
else json.dumps(async_gen)
|
else orjson.dumps(async_gen)
|
||||||
)
|
)
|
||||||
yield json_line.encode("utf-8") + b"\n"
|
yield json_line.encode("utf-8") + b"\n"
|
||||||
|
|
||||||
|
|
@ -882,7 +888,7 @@ async def embedding_proxy(request: Request):
|
||||||
# 1. Parse and validate request
|
# 1. Parse and validate request
|
||||||
try:
|
try:
|
||||||
body_bytes = await request.body()
|
body_bytes = await request.body()
|
||||||
payload = json.loads(body_bytes.decode("utf-8"))
|
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||||
|
|
||||||
model = payload.get("model")
|
model = payload.get("model")
|
||||||
prompt = payload.get("prompt")
|
prompt = payload.get("prompt")
|
||||||
|
|
@ -897,7 +903,7 @@ async def embedding_proxy(request: Request):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Missing required field 'prompt'"
|
status_code=400, detail="Missing required field 'prompt'"
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError as e:
|
except orjson.JSONDecodeError as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
# 2. Endpoint logic
|
# 2. Endpoint logic
|
||||||
|
|
@ -923,7 +929,7 @@ async def embedding_proxy(request: Request):
|
||||||
if hasattr(async_gen, "model_dump_json"):
|
if hasattr(async_gen, "model_dump_json"):
|
||||||
json_line = async_gen.model_dump_json()
|
json_line = async_gen.model_dump_json()
|
||||||
else:
|
else:
|
||||||
json_line = json.dumps(async_gen)
|
json_line = orjson.dumps(async_gen)
|
||||||
yield json_line.encode("utf-8") + b"\n"
|
yield json_line.encode("utf-8") + b"\n"
|
||||||
finally:
|
finally:
|
||||||
# Ensure counter is decremented even if an exception occurs
|
# Ensure counter is decremented even if an exception occurs
|
||||||
|
|
@ -947,7 +953,7 @@ async def embed_proxy(request: Request):
|
||||||
# 1. Parse and validate request
|
# 1. Parse and validate request
|
||||||
try:
|
try:
|
||||||
body_bytes = await request.body()
|
body_bytes = await request.body()
|
||||||
payload = json.loads(body_bytes.decode("utf-8"))
|
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||||
|
|
||||||
model = payload.get("model")
|
model = payload.get("model")
|
||||||
_input = payload.get("input")
|
_input = payload.get("input")
|
||||||
|
|
@ -963,7 +969,7 @@ async def embed_proxy(request: Request):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Missing required field 'input'"
|
status_code=400, detail="Missing required field 'input'"
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError as e:
|
except orjson.JSONDecodeError as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
# 2. Endpoint logic
|
# 2. Endpoint logic
|
||||||
|
|
@ -989,7 +995,7 @@ async def embed_proxy(request: Request):
|
||||||
if hasattr(async_gen, "model_dump_json"):
|
if hasattr(async_gen, "model_dump_json"):
|
||||||
json_line = async_gen.model_dump_json()
|
json_line = async_gen.model_dump_json()
|
||||||
else:
|
else:
|
||||||
json_line = json.dumps(async_gen)
|
json_line = orjson.dumps(async_gen)
|
||||||
yield json_line.encode("utf-8") + b"\n"
|
yield json_line.encode("utf-8") + b"\n"
|
||||||
finally:
|
finally:
|
||||||
# Ensure counter is decremented even if an exception occurs
|
# Ensure counter is decremented even if an exception occurs
|
||||||
|
|
@ -1011,7 +1017,7 @@ async def create_proxy(request: Request):
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
body_bytes = await request.body()
|
body_bytes = await request.body()
|
||||||
payload = json.loads(body_bytes.decode("utf-8"))
|
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||||
|
|
||||||
model = payload.get("model")
|
model = payload.get("model")
|
||||||
quantize = payload.get("quantize")
|
quantize = payload.get("quantize")
|
||||||
|
|
@ -1032,7 +1038,7 @@ async def create_proxy(request: Request):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="You need to provide either from_ or files parameter!"
|
status_code=400, detail="You need to provide either from_ or files parameter!"
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError as e:
|
except orjson.JSONDecodeError as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
status_lists = []
|
status_lists = []
|
||||||
|
|
@ -1062,14 +1068,14 @@ async def show_proxy(request: Request, model: Optional[str] = None):
|
||||||
body_bytes = await request.body()
|
body_bytes = await request.body()
|
||||||
|
|
||||||
if not model:
|
if not model:
|
||||||
payload = json.loads(body_bytes.decode("utf-8"))
|
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||||
model = payload.get("model")
|
model = payload.get("model")
|
||||||
|
|
||||||
if not model:
|
if not model:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Missing required field 'model'"
|
status_code=400, detail="Missing required field 'model'"
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError as e:
|
except orjson.JSONDecodeError as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
# 2. Endpoint logic
|
# 2. Endpoint logic
|
||||||
|
|
@ -1097,7 +1103,7 @@ async def copy_proxy(request: Request, source: Optional[str] = None, destination
|
||||||
body_bytes = await request.body()
|
body_bytes = await request.body()
|
||||||
|
|
||||||
if not source and not destination:
|
if not source and not destination:
|
||||||
payload = json.loads(body_bytes.decode("utf-8"))
|
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||||
src = payload.get("source")
|
src = payload.get("source")
|
||||||
dst = payload.get("destination")
|
dst = payload.get("destination")
|
||||||
else:
|
else:
|
||||||
|
|
@ -1112,7 +1118,7 @@ async def copy_proxy(request: Request, source: Optional[str] = None, destination
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Missing required field 'destination'"
|
status_code=400, detail="Missing required field 'destination'"
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError as e:
|
except orjson.JSONDecodeError as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
# 3. Iterate over all endpoints to copy the model on each endpoint
|
# 3. Iterate over all endpoints to copy the model on each endpoint
|
||||||
|
|
@ -1141,14 +1147,14 @@ async def delete_proxy(request: Request, model: Optional[str] = None):
|
||||||
body_bytes = await request.body()
|
body_bytes = await request.body()
|
||||||
|
|
||||||
if not model:
|
if not model:
|
||||||
payload = json.loads(body_bytes.decode("utf-8"))
|
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||||
model = payload.get("model")
|
model = payload.get("model")
|
||||||
|
|
||||||
if not model:
|
if not model:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Missing required field 'model'"
|
status_code=400, detail="Missing required field 'model'"
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError as e:
|
except orjson.JSONDecodeError as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
# 2. Iterate over all endpoints to delete the model on each endpoint
|
# 2. Iterate over all endpoints to delete the model on each endpoint
|
||||||
|
|
@ -1176,7 +1182,7 @@ async def pull_proxy(request: Request, model: Optional[str] = None):
|
||||||
body_bytes = await request.body()
|
body_bytes = await request.body()
|
||||||
|
|
||||||
if not model:
|
if not model:
|
||||||
payload = json.loads(body_bytes.decode("utf-8"))
|
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||||
model = payload.get("model")
|
model = payload.get("model")
|
||||||
insecure = payload.get("insecure")
|
insecure = payload.get("insecure")
|
||||||
else:
|
else:
|
||||||
|
|
@ -1186,7 +1192,7 @@ async def pull_proxy(request: Request, model: Optional[str] = None):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Missing required field 'model'"
|
status_code=400, detail="Missing required field 'model'"
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError as e:
|
except orjson.JSONDecodeError as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
# 2. Iterate over all endpoints to pull the model
|
# 2. Iterate over all endpoints to pull the model
|
||||||
|
|
@ -1218,7 +1224,7 @@ async def push_proxy(request: Request):
|
||||||
# 1. Parse and validate request
|
# 1. Parse and validate request
|
||||||
try:
|
try:
|
||||||
body_bytes = await request.body()
|
body_bytes = await request.body()
|
||||||
payload = json.loads(body_bytes.decode("utf-8"))
|
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||||
|
|
||||||
model = payload.get("model")
|
model = payload.get("model")
|
||||||
insecure = payload.get("insecure")
|
insecure = payload.get("insecure")
|
||||||
|
|
@ -1227,7 +1233,7 @@ async def push_proxy(request: Request):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Missing required field 'model'"
|
status_code=400, detail="Missing required field 'model'"
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError as e:
|
except orjson.JSONDecodeError as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
# 2. Iterate over all endpoints
|
# 2. Iterate over all endpoints
|
||||||
|
|
@ -1385,7 +1391,7 @@ async def openai_embedding_proxy(request: Request):
|
||||||
# 1. Parse and validate request
|
# 1. Parse and validate request
|
||||||
try:
|
try:
|
||||||
body_bytes = await request.body()
|
body_bytes = await request.body()
|
||||||
payload = json.loads(body_bytes.decode("utf-8"))
|
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||||
|
|
||||||
model = payload.get("model")
|
model = payload.get("model")
|
||||||
doc = payload.get("input")
|
doc = payload.get("input")
|
||||||
|
|
@ -1399,7 +1405,7 @@ async def openai_embedding_proxy(request: Request):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Missing required field 'input'"
|
status_code=400, detail="Missing required field 'input'"
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError as e:
|
except orjson.JSONDecodeError as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
# 2. Endpoint logic
|
# 2. Endpoint logic
|
||||||
|
|
@ -1432,7 +1438,7 @@ async def openai_chat_completions_proxy(request: Request):
|
||||||
# 1. Parse and validate request
|
# 1. Parse and validate request
|
||||||
try:
|
try:
|
||||||
body_bytes = await request.body()
|
body_bytes = await request.body()
|
||||||
payload = json.loads(body_bytes.decode("utf-8"))
|
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||||
|
|
||||||
model = payload.get("model")
|
model = payload.get("model")
|
||||||
messages = payload.get("messages")
|
messages = payload.get("messages")
|
||||||
|
|
@ -1483,7 +1489,7 @@ async def openai_chat_completions_proxy(request: Request):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Missing required field 'messages' (must be a list)"
|
status_code=400, detail="Missing required field 'messages' (must be a list)"
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError as e:
|
except orjson.JSONDecodeError as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
# 2. Endpoint logic
|
# 2. Endpoint logic
|
||||||
|
|
@ -1501,7 +1507,7 @@ async def openai_chat_completions_proxy(request: Request):
|
||||||
data = (
|
data = (
|
||||||
chunk.model_dump_json()
|
chunk.model_dump_json()
|
||||||
if hasattr(chunk, "model_dump_json")
|
if hasattr(chunk, "model_dump_json")
|
||||||
else json.dumps(chunk)
|
else orjson.dumps(chunk)
|
||||||
)
|
)
|
||||||
if chunk.choices[0].delta.content is not None:
|
if chunk.choices[0].delta.content is not None:
|
||||||
yield f"data: {data}\n\n".encode("utf-8")
|
yield f"data: {data}\n\n".encode("utf-8")
|
||||||
|
|
@ -1509,11 +1515,11 @@ async def openai_chat_completions_proxy(request: Request):
|
||||||
else:
|
else:
|
||||||
prompt_tok = async_gen.usage.prompt_tokens or 0
|
prompt_tok = async_gen.usage.prompt_tokens or 0
|
||||||
comp_tok = async_gen.usage.completion_tokens or 0
|
comp_tok = async_gen.usage.completion_tokens or 0
|
||||||
record_token_usage(endpoint, payload.get("model"), prompt_tok, comp_tok)
|
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
|
||||||
json_line = (
|
json_line = (
|
||||||
async_gen.model_dump_json()
|
async_gen.model_dump_json()
|
||||||
if hasattr(async_gen, "model_dump_json")
|
if hasattr(async_gen, "model_dump_json")
|
||||||
else json.dumps(async_gen)
|
else orjson.dumps(async_gen)
|
||||||
)
|
)
|
||||||
yield json_line.encode("utf-8") + b"\n"
|
yield json_line.encode("utf-8") + b"\n"
|
||||||
|
|
||||||
|
|
@ -1539,7 +1545,7 @@ async def openai_completions_proxy(request: Request):
|
||||||
# 1. Parse and validate request
|
# 1. Parse and validate request
|
||||||
try:
|
try:
|
||||||
body_bytes = await request.body()
|
body_bytes = await request.body()
|
||||||
payload = json.loads(body_bytes.decode("utf-8"))
|
payload = orjson.loads(body_bytes.decode("utf-8"))
|
||||||
|
|
||||||
model = payload.get("model")
|
model = payload.get("model")
|
||||||
prompt = payload.get("prompt")
|
prompt = payload.get("prompt")
|
||||||
|
|
@ -1588,7 +1594,7 @@ async def openai_completions_proxy(request: Request):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Missing required field 'prompt'"
|
status_code=400, detail="Missing required field 'prompt'"
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError as e:
|
except orjson.JSONDecodeError as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
# 2. Endpoint logic
|
# 2. Endpoint logic
|
||||||
|
|
@ -1607,7 +1613,7 @@ async def openai_completions_proxy(request: Request):
|
||||||
data = (
|
data = (
|
||||||
chunk.model_dump_json()
|
chunk.model_dump_json()
|
||||||
if hasattr(chunk, "model_dump_json")
|
if hasattr(chunk, "model_dump_json")
|
||||||
else json.dumps(chunk)
|
else orjson.dumps(chunk)
|
||||||
)
|
)
|
||||||
yield f"data: {data}\n\n".encode("utf-8")
|
yield f"data: {data}\n\n".encode("utf-8")
|
||||||
# Final DONE event
|
# Final DONE event
|
||||||
|
|
@ -1615,11 +1621,11 @@ async def openai_completions_proxy(request: Request):
|
||||||
else:
|
else:
|
||||||
prompt_tok = async_gen.usage.prompt_tokens or 0
|
prompt_tok = async_gen.usage.prompt_tokens or 0
|
||||||
comp_tok = async_gen.usage.completion_tokens or 0
|
comp_tok = async_gen.usage.completion_tokens or 0
|
||||||
record_token_usage(endpoint, payload.get("model"), prompt_tok, comp_tok)
|
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
|
||||||
json_line = (
|
json_line = (
|
||||||
async_gen.model_dump_json()
|
async_gen.model_dump_json()
|
||||||
if hasattr(async_gen, "model_dump_json")
|
if hasattr(async_gen, "model_dump_json")
|
||||||
else json.dumps(async_gen)
|
else orjson.dumps(async_gen)
|
||||||
)
|
)
|
||||||
yield json_line.encode("utf-8") + b"\n"
|
yield json_line.encode("utf-8") + b"\n"
|
||||||
|
|
||||||
|
|
@ -1774,7 +1780,7 @@ async def startup_event() -> None:
|
||||||
|
|
||||||
app_state["connector"] = connector
|
app_state["connector"] = connector
|
||||||
app_state["session"] = session
|
app_state["session"] = session
|
||||||
|
asyncio.create_task(token_worker())
|
||||||
|
|
||||||
@app.on_event("shutdown")
|
@app.on_event("shutdown")
|
||||||
async def shutdown_event() -> None:
|
async def shutdown_event() -> None:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue