various performance improvements; replace stdlib json with orjson
This commit is contained in:
  parent c6c1059ede
  commit 1427e98e6d

2 changed files with 70 additions and 62 deletions

router.py: 130
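The swap from the stdlib json module to orjson is mechanical except for two API differences worth noting: orjson.dumps() returns bytes rather than str (hence the .decode("utf-8") added in publish_snapshot below), and key sorting is requested via option=orjson.OPT_SORT_KEYS instead of sort_keys=True. A minimal sketch of the correspondence (illustrative, not part of the commit):

    import json
    import orjson

    data = {"b": 1, "a": 2}

    # stdlib: returns str
    s1 = json.dumps(data, sort_keys=True)

    # orjson: returns bytes; decode when a str is needed
    s2 = orjson.dumps(data, option=orjson.OPT_SORT_KEYS).decode("utf-8")

    assert s1.replace(" ", "") == s2      # orjson emits compact JSON (no spaces)
    assert orjson.loads(s2) == json.loads(s1)  # loads accepts bytes or str

orjson.JSONDecodeError is a subclass of both ValueError and json.JSONDecodeError, so the rewritten except clauses below keep their previous behaviour.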
--- a/router.py
+++ b/router.py
@@ -6,7 +6,7 @@ version: 0.4
 license: AGPL
 """
 # -------------------------------------------------------------
-import json, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random, base64, io
+import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random, base64, io
 from pathlib import Path
 from typing import Dict, Set, List, Optional
 from urllib.parse import urlparse
@@ -30,10 +30,11 @@ _models_cache: dict[str, tuple[Set[str], float]] = {}
 _error_cache: dict[str, float] = {}
 
 # ------------------------------------------------------------------
-# SSE Queues
+# Queues
 # ------------------------------------------------------------------
 _subscribers: Set[asyncio.Queue] = set()
 _subscribers_lock = asyncio.Lock()
+token_queue: asyncio.Queue[tuple[str, str, int, int]] = asyncio.Queue()
 
 # ------------------------------------------------------------------
 # aiohttp Global Sessions
@@ -125,6 +126,7 @@ default_headers={
 usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
 token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
 usage_lock = asyncio.Lock()  # protects access to usage_counts
+token_usage_lock = asyncio.Lock()
 
 # -------------------------------------------------------------
 # 4. Helperfunctions
@@ -192,12 +194,12 @@ def is_ext_openai_endpoint(endpoint: str) -> bool:
 
     return True  # It's an external OpenAI endpoint
 
-def record_token_usage(endpoint: str, model: str, prompt: int = 0, completion: int = 0) -> None:
-    async def _record():
-        async with usage_lock:  # reuse the same lock that protects usage_counts
-            token_usage_counts[endpoint][model] += (prompt + completion)
-            await publish_snapshot()  # immediately broadcast the new totals
-    asyncio.create_task(_record())
+async def token_worker() -> None:
+    while True:
+        endpoint, model, prompt, comp = await token_queue.get()
+        async with token_usage_lock:
+            token_usage_counts[endpoint][model] += (prompt + comp)
+        await publish_snapshot()
 
 class fetch:
     async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
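The replaced helper is worth a note: record_token_usage() spawned a fresh task per update, and that task took usage_lock and then called publish_snapshot(), which acquires usage_lock again; asyncio.Lock is not reentrant, so that nesting can deadlock, which is presumably part of the motivation for the rewrite. The new design funnels every update through token_queue and drains it in one long-lived token_worker(), holding the dedicated token_usage_lock only for the increment. A self-contained sketch of the same producer/consumer pattern (simplified names, not the commit's code):

    import asyncio
    from collections import defaultdict

    counts: dict[str, int] = defaultdict(int)
    queue: asyncio.Queue[tuple[str, int]] = asyncio.Queue()
    lock = asyncio.Lock()

    async def worker() -> None:
        # Single long-lived consumer: callers never contend on the lock.
        while True:
            key, n = await queue.get()
            async with lock:
                counts[key] += n
            queue.task_done()

    async def main() -> None:
        asyncio.create_task(worker())
        for _ in range(3):
            await queue.put(("model-a", 10))  # cheap fire-and-forget from the hot path
        await queue.join()                    # wait until the worker has drained everything
        print(dict(counts))                   # {'model-a': 30}

    asyncio.run(main())

Producers pay only the cost of queue.put(); counting and the subsequent snapshot broadcast happen off the request path.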
@@ -267,6 +269,8 @@ class fetch:
         set is returned.
         """
         client: aiohttp.ClientSession = app_state["session"]
+        if is_ext_openai_endpoint(endpoint):
+            return set()
         try:
             async with client.get(f"{endpoint}/api/ps") as resp:
                 await _ensure_success(resp)
@@ -428,18 +432,19 @@ def transform_images_to_data_urls(message_list):
 
 class rechunk:
     def openai_chat_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.ChatResponse:
+        now = time.perf_counter()
         if chunk.choices == [] and chunk.usage is not None:
             return ollama.ChatResponse(
                 model=chunk.model,
                 created_at=iso8601_ns(),
                 done=True,
                 done_reason='stop',
-                total_duration=int((time.perf_counter() - start_ts) * 1_000_000_000),
+                total_duration=int((now - start_ts) * 1_000_000_000),
                 load_duration=100000,
                 prompt_eval_count=int(chunk.usage.prompt_tokens),
-                prompt_eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)),
+                prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)),
                 eval_count=int(chunk.usage.completion_tokens),
-                eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000),
+                eval_duration=int((now - start_ts) * 1_000_000_000),
                 message={"role": "assistant"}
             )
         with_thinking = chunk.choices[0] if chunk.choices[0] else None
@@ -463,16 +468,17 @@ class rechunk:
             created_at=iso8601_ns(),
             done=True if chunk.usage is not None else False,
             done_reason=chunk.choices[0].finish_reason, #if chunk.choices[0].finish_reason is not None else None,
-            total_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
+            total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
             load_duration=100000,
             prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
-            prompt_eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
+            prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
             eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
-            eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
+            eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
             message=assistant_msg)
         return rechunk
 
     def openai_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.GenerateResponse:
+        now = time.perf_counter()
         with_thinking = chunk.choices[0] if chunk.choices[0] else None
         thinking = getattr(with_thinking, "reasoning", None) if with_thinking else None
         rechunk = ollama.GenerateResponse(
@@ -480,12 +486,12 @@ class rechunk:
             created_at=iso8601_ns(),
             done=True if chunk.usage is not None else False,
             done_reason=chunk.choices[0].finish_reason,
-            total_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
+            total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
             load_duration=10000,
             prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
-            prompt_eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
+            prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
             eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
-            eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
+            eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
             response=chunk.choices[0].text or '',
             thinking=thinking)
         return rechunk
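The rechunk hunks above hoist a single now = time.perf_counter() to the top of each converter instead of sampling the clock once per duration field. Besides saving a few calls per streamed chunk, every derived duration (total_duration, prompt_eval_duration, eval_duration) is now computed from the same instant, so the fields stay mutually consistent. In miniature (illustrative):

    import time

    start_ts = time.perf_counter()
    # ... work being timed ...
    now = time.perf_counter()  # sample the clock once
    total_ns = int((now - start_ts) * 1_000_000_000)
    eval_ns = int((now - start_ts) * 1_000_000_000)
    assert total_ns == eval_ns  # equal by construction; repeated perf_counter() calls could differ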
@@ -514,9 +520,9 @@ class rechunk:
 # ------------------------------------------------------------------
 async def publish_snapshot():
     async with usage_lock:
-        snapshot = json.dumps({"usage_counts": usage_counts,
+        snapshot = orjson.dumps({"usage_counts": usage_counts,
                                "token_usage_counts": token_usage_counts,
-                               }, sort_keys=True)
+                               }, option=orjson.OPT_SORT_KEYS).decode("utf-8")
     async with _subscribers_lock:
         for q in _subscribers:
             # If the queue is full, drop the message to avoid back‑pressure.
@@ -650,7 +656,7 @@ async def proxy(request: Request):
     """
     try:
         body_bytes = await request.body()
-        payload = json.loads(body_bytes.decode("utf-8"))
+        payload = orjson.loads(body_bytes.decode("utf-8"))
 
         model = payload.get("model")
         prompt = payload.get("prompt")
@@ -674,7 +680,7 @@ async def proxy(request: Request):
             raise HTTPException(
                 status_code=400, detail="Missing required field 'prompt'"
             )
-    except json.JSONDecodeError as e:
+    except orjson.JSONDecodeError as e:
         error_msg = f"Invalid JSON format in request body: {str(e)}. Please ensure the request is properly formatted."
         raise HTTPException(status_code=400, detail=error_msg) from e
 
@@ -721,11 +727,11 @@ async def proxy(request: Request):
                     chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
                     prompt_tok = chunk.prompt_eval_count or 0
                     comp_tok = chunk.eval_count or 0
-                    record_token_usage(endpoint, model, prompt_tok, comp_tok)
+                    await token_queue.put((endpoint, model, prompt_tok, comp_tok))
                     if hasattr(chunk, "model_dump_json"):
                         json_line = chunk.model_dump_json()
                     else:
-                        json_line = json.dumps(chunk)
+                        json_line = orjson.dumps(chunk)
                     yield json_line.encode("utf-8") + b"\n"
             else:
                 if is_openai_endpoint:
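These streaming fallbacks are the one spot where orjson's bytes return type can bite: model_dump_json() yields str, orjson.dumps() yields bytes, and both feed the same json_line.encode("utf-8") call, which only a str supports. The else branch is reachable only for plain dicts, but a small normalizing helper would avoid the trap entirely; a hedged sketch (hypothetical helper, not in the commit):

    import orjson

    def to_json_bytes(obj) -> bytes:
        # Pydantic models serialize themselves to str; everything else goes
        # through orjson, which already returns bytes.
        if hasattr(obj, "model_dump_json"):
            return obj.model_dump_json().encode("utf-8")
        return orjson.dumps(obj)

    # usage inside a generator:
    # yield to_json_bytes(chunk) + b"\n"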
@@ -735,11 +741,11 @@ async def proxy(request: Request):
                    response = async_gen.model_dump_json()
                    prompt_tok = async_gen.prompt_eval_count or 0
                    comp_tok = async_gen.eval_count or 0
-                    record_token_usage(endpoint, model, prompt_tok, comp_tok)
+                    await token_queue.put((endpoint, model, prompt_tok, comp_tok))
                json_line = (
                    response
                    if hasattr(async_gen, "model_dump_json")
-                    else json.dumps(async_gen)
+                    else orjson.dumps(async_gen)
                )
                yield json_line.encode("utf-8") + b"\n"
 
@@ -764,7 +770,7 @@ async def chat_proxy(request: Request):
     # 1. Parse and validate request
     try:
         body_bytes = await request.body()
-        payload = json.loads(body_bytes.decode("utf-8"))
+        payload = orjson.loads(body_bytes.decode("utf-8"))
 
         model = payload.get("model")
         messages = payload.get("messages")
@@ -787,7 +793,7 @@ async def chat_proxy(request: Request):
             raise HTTPException(
                 status_code=400, detail="`options` must be a JSON object"
             )
-    except json.JSONDecodeError as e:
+    except orjson.JSONDecodeError as e:
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     # 2. Endpoint logic
@@ -837,11 +843,11 @@ async def chat_proxy(request: Request):
                     # `chunk` can be a dict or a pydantic model – dump to JSON safely
                     prompt_tok = chunk.prompt_eval_count or 0
                     comp_tok = chunk.eval_count or 0
-                    record_token_usage(endpoint, model, prompt_tok, comp_tok)
+                    await token_queue.put((endpoint, model, prompt_tok, comp_tok))
                     if hasattr(chunk, "model_dump_json"):
                         json_line = chunk.model_dump_json()
                     else:
-                        json_line = json.dumps(chunk)
+                        json_line = orjson.dumps(chunk)
                     yield json_line.encode("utf-8") + b"\n"
             else:
                 if is_openai_endpoint:
@@ -851,11 +857,11 @@ async def chat_proxy(request: Request):
                    response = async_gen.model_dump_json()
                    prompt_tok = async_gen.prompt_eval_count or 0
                    comp_tok = async_gen.eval_count or 0
-                    record_token_usage(endpoint, model, prompt_tok, comp_tok)
+                    await token_queue.put((endpoint, model, prompt_tok, comp_tok))
                json_line = (
                    response
                    if hasattr(async_gen, "model_dump_json")
-                    else json.dumps(async_gen)
+                    else orjson.dumps(async_gen)
                )
                yield json_line.encode("utf-8") + b"\n"
 
@@ -882,7 +888,7 @@ async def embedding_proxy(request: Request):
     # 1. Parse and validate request
     try:
         body_bytes = await request.body()
-        payload = json.loads(body_bytes.decode("utf-8"))
+        payload = orjson.loads(body_bytes.decode("utf-8"))
 
         model = payload.get("model")
         prompt = payload.get("prompt")
@@ -897,7 +903,7 @@ async def embedding_proxy(request: Request):
             raise HTTPException(
                 status_code=400, detail="Missing required field 'prompt'"
             )
-    except json.JSONDecodeError as e:
+    except orjson.JSONDecodeError as e:
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     # 2. Endpoint logic
@@ -923,7 +929,7 @@ async def embedding_proxy(request: Request):
                if hasattr(async_gen, "model_dump_json"):
                    json_line = async_gen.model_dump_json()
                else:
-                    json_line = json.dumps(async_gen)
+                    json_line = orjson.dumps(async_gen)
                yield json_line.encode("utf-8") + b"\n"
        finally:
            # Ensure counter is decremented even if an exception occurs
@@ -947,7 +953,7 @@ async def embed_proxy(request: Request):
     # 1. Parse and validate request
     try:
         body_bytes = await request.body()
-        payload = json.loads(body_bytes.decode("utf-8"))
+        payload = orjson.loads(body_bytes.decode("utf-8"))
 
         model = payload.get("model")
         _input = payload.get("input")
@@ -963,7 +969,7 @@ async def embed_proxy(request: Request):
             raise HTTPException(
                 status_code=400, detail="Missing required field 'input'"
             )
-    except json.JSONDecodeError as e:
+    except orjson.JSONDecodeError as e:
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     # 2. Endpoint logic
@@ -989,7 +995,7 @@ async def embed_proxy(request: Request):
                if hasattr(async_gen, "model_dump_json"):
                    json_line = async_gen.model_dump_json()
                else:
-                    json_line = json.dumps(async_gen)
+                    json_line = orjson.dumps(async_gen)
                yield json_line.encode("utf-8") + b"\n"
        finally:
            # Ensure counter is decremented even if an exception occurs
@@ -1011,7 +1017,7 @@ async def create_proxy(request: Request):
     """
     try:
         body_bytes = await request.body()
-        payload = json.loads(body_bytes.decode("utf-8"))
+        payload = orjson.loads(body_bytes.decode("utf-8"))
 
         model = payload.get("model")
         quantize = payload.get("quantize")
@@ -1032,7 +1038,7 @@ async def create_proxy(request: Request):
             raise HTTPException(
                 status_code=400, detail="You need to provide either from_ or files parameter!"
             )
-    except json.JSONDecodeError as e:
+    except orjson.JSONDecodeError as e:
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     status_lists = []
@@ -1062,14 +1068,14 @@ async def show_proxy(request: Request, model: Optional[str] = None):
         body_bytes = await request.body()
 
         if not model:
-            payload = json.loads(body_bytes.decode("utf-8"))
+            payload = orjson.loads(body_bytes.decode("utf-8"))
             model = payload.get("model")
 
         if not model:
             raise HTTPException(
                 status_code=400, detail="Missing required field 'model'"
             )
-    except json.JSONDecodeError as e:
+    except orjson.JSONDecodeError as e:
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     # 2. Endpoint logic
@@ -1097,7 +1103,7 @@ async def copy_proxy(request: Request, source: Optional[str] = None, destination
         body_bytes = await request.body()
 
         if not source and not destination:
-            payload = json.loads(body_bytes.decode("utf-8"))
+            payload = orjson.loads(body_bytes.decode("utf-8"))
             src = payload.get("source")
             dst = payload.get("destination")
         else:
@@ -1112,7 +1118,7 @@ async def copy_proxy(request: Request, source: Optional[str] = None, destination
             raise HTTPException(
                 status_code=400, detail="Missing required field 'destination'"
             )
-    except json.JSONDecodeError as e:
+    except orjson.JSONDecodeError as e:
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     # 3. Iterate over all endpoints to copy the model on each endpoint
@@ -1141,14 +1147,14 @@ async def delete_proxy(request: Request, model: Optional[str] = None):
         body_bytes = await request.body()
 
         if not model:
-            payload = json.loads(body_bytes.decode("utf-8"))
+            payload = orjson.loads(body_bytes.decode("utf-8"))
             model = payload.get("model")
 
         if not model:
             raise HTTPException(
                 status_code=400, detail="Missing required field 'model'"
             )
-    except json.JSONDecodeError as e:
+    except orjson.JSONDecodeError as e:
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     # 2. Iterate over all endpoints to delete the model on each endpoint
@@ -1176,7 +1182,7 @@ async def pull_proxy(request: Request, model: Optional[str] = None):
         body_bytes = await request.body()
 
         if not model:
-            payload = json.loads(body_bytes.decode("utf-8"))
+            payload = orjson.loads(body_bytes.decode("utf-8"))
             model = payload.get("model")
             insecure = payload.get("insecure")
         else:
@@ -1186,7 +1192,7 @@ async def pull_proxy(request: Request, model: Optional[str] = None):
             raise HTTPException(
                 status_code=400, detail="Missing required field 'model'"
             )
-    except json.JSONDecodeError as e:
+    except orjson.JSONDecodeError as e:
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     # 2. Iterate over all endpoints to pull the model
@@ -1218,7 +1224,7 @@ async def push_proxy(request: Request):
     # 1. Parse and validate request
     try:
         body_bytes = await request.body()
-        payload = json.loads(body_bytes.decode("utf-8"))
+        payload = orjson.loads(body_bytes.decode("utf-8"))
 
         model = payload.get("model")
         insecure = payload.get("insecure")
@@ -1227,7 +1233,7 @@ async def push_proxy(request: Request):
             raise HTTPException(
                 status_code=400, detail="Missing required field 'model'"
             )
-    except json.JSONDecodeError as e:
+    except orjson.JSONDecodeError as e:
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     # 2. Iterate over all endpoints
@@ -1385,7 +1391,7 @@ async def openai_embedding_proxy(request: Request):
     # 1. Parse and validate request
     try:
         body_bytes = await request.body()
-        payload = json.loads(body_bytes.decode("utf-8"))
+        payload = orjson.loads(body_bytes.decode("utf-8"))
 
         model = payload.get("model")
         doc = payload.get("input")
@@ -1399,7 +1405,7 @@ async def openai_embedding_proxy(request: Request):
             raise HTTPException(
                 status_code=400, detail="Missing required field 'input'"
             )
-    except json.JSONDecodeError as e:
+    except orjson.JSONDecodeError as e:
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     # 2. Endpoint logic
@@ -1432,7 +1438,7 @@ async def openai_chat_completions_proxy(request: Request):
     # 1. Parse and validate request
     try:
         body_bytes = await request.body()
-        payload = json.loads(body_bytes.decode("utf-8"))
+        payload = orjson.loads(body_bytes.decode("utf-8"))
 
         model = payload.get("model")
         messages = payload.get("messages")
@@ -1483,7 +1489,7 @@ async def openai_chat_completions_proxy(request: Request):
             raise HTTPException(
                 status_code=400, detail="Missing required field 'messages' (must be a list)"
             )
-    except json.JSONDecodeError as e:
+    except orjson.JSONDecodeError as e:
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     # 2. Endpoint logic
@@ -1501,7 +1507,7 @@ async def openai_chat_completions_proxy(request: Request):
                    data = (
                        chunk.model_dump_json()
                        if hasattr(chunk, "model_dump_json")
-                        else json.dumps(chunk)
+                        else orjson.dumps(chunk)
                    )
                    if chunk.choices[0].delta.content is not None:
                        yield f"data: {data}\n\n".encode("utf-8")
@@ -1509,11 +1515,11 @@ async def openai_chat_completions_proxy(request: Request):
            else:
                prompt_tok = async_gen.usage.prompt_tokens or 0
                comp_tok = async_gen.usage.completion_tokens or 0
-                record_token_usage(endpoint, payload.get("model"), prompt_tok, comp_tok)
+                await token_queue.put((endpoint, model, prompt_tok, comp_tok))
                json_line = (
                    async_gen.model_dump_json()
                    if hasattr(async_gen, "model_dump_json")
-                    else json.dumps(async_gen)
+                    else orjson.dumps(async_gen)
                )
                yield json_line.encode("utf-8") + b"\n"
 
@@ -1539,7 +1545,7 @@ async def openai_completions_proxy(request: Request):
     # 1. Parse and validate request
     try:
         body_bytes = await request.body()
-        payload = json.loads(body_bytes.decode("utf-8"))
+        payload = orjson.loads(body_bytes.decode("utf-8"))
 
         model = payload.get("model")
         prompt = payload.get("prompt")
@@ -1588,7 +1594,7 @@ async def openai_completions_proxy(request: Request):
             raise HTTPException(
                 status_code=400, detail="Missing required field 'prompt'"
             )
-    except json.JSONDecodeError as e:
+    except orjson.JSONDecodeError as e:
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
 
     # 2. Endpoint logic
@@ -1607,7 +1613,7 @@ async def openai_completions_proxy(request: Request):
                    data = (
                        chunk.model_dump_json()
                        if hasattr(chunk, "model_dump_json")
-                        else json.dumps(chunk)
+                        else orjson.dumps(chunk)
                    )
                    yield f"data: {data}\n\n".encode("utf-8")
                # Final DONE event
@@ -1615,11 +1621,11 @@ async def openai_completions_proxy(request: Request):
            else:
                prompt_tok = async_gen.usage.prompt_tokens or 0
                comp_tok = async_gen.usage.completion_tokens or 0
-                record_token_usage(endpoint, payload.get("model"), prompt_tok, comp_tok)
+                await token_queue.put((endpoint, model, prompt_tok, comp_tok))
                json_line = (
                    async_gen.model_dump_json()
                    if hasattr(async_gen, "model_dump_json")
-                    else json.dumps(async_gen)
+                    else orjson.dumps(async_gen)
                )
                yield json_line.encode("utf-8") + b"\n"
 
@@ -1774,7 +1780,7 @@ async def startup_event() -> None:
 
     app_state["connector"] = connector
     app_state["session"] = session
-
+    asyncio.create_task(token_worker())
 
 @app.on_event("shutdown")
 async def shutdown_event() -> None:
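Starting the worker inside startup_event() ties it to the application's event loop. One caveat with a bare asyncio.create_task() is that nothing keeps a reference to the returned Task, so it can be garbage-collected mid-flight and cannot be cancelled cleanly at shutdown; a common variant (illustrative, reusing this file's app_state dict) retains the handle:

    # in startup_event()
    app_state["token_worker_task"] = asyncio.create_task(token_worker())

    # in shutdown_event()
    task = app_state.get("token_worker_task")
    if task is not None:
        task.cancel()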