diff --git a/.gitignore b/.gitignore
index 4bb65cd..702c855 100644
--- a/.gitignore
+++ b/.gitignore
@@ -51,10 +51,6 @@ cython_debug/

 # VS Code files for those working on multiple tools
 .vscode/*
-.vscode/settings.json
-!.vscode/tasks.json
-!.vscode/launch.json
-!.vscode/extensions.json
 *.code-workspace

 # Local History for Visual Studio Code
diff --git a/requirements.txt b/requirements.txt
index f3ad896..e296839 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 aiohappyeyeballs==2.6.1
 aiohttp==3.12.15
 aiosignal==1.4.0
+annotated-doc==0.0.3
 annotated-types==0.7.0
 anyio==4.10.0
 async-timeout==5.0.1
@@ -20,6 +21,7 @@ jiter==0.10.0
 multidict==6.6.4
 ollama==0.6.0
 openai==1.102.0
+orjson==3.11.4
 pillow==11.3.0
 propcache==0.3.2
 pydantic==2.11.7
diff --git a/router.py b/router.py
index 2deb639..dff8aa9 100644
--- a/router.py
+++ b/router.py
@@ -6,7 +6,7 @@ version: 0.4
 license: AGPL
 """
 # -------------------------------------------------------------
-import json, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random, base64, io
+import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, datetime, random, base64, io
 from pathlib import Path
 from typing import Dict, Set, List, Optional
 from urllib.parse import urlparse
@@ -25,23 +25,26 @@ from PIL import Image
 # ------------------------------------------------------------------
 # Successful results are cached for 300s
 _models_cache: dict[str, tuple[Set[str], float]] = {}
+_loaded_models_cache: dict[str, tuple[Set[str], float]] = {}
 # Transient errors are cached for 1s – the key stays until the
 # timeout expires, after which the endpoint will be queried again.
 _error_cache: dict[str, float] = {}

 # ------------------------------------------------------------------
-# SSE Queues
+# Queues
 # ------------------------------------------------------------------
 _subscribers: Set[asyncio.Queue] = set()
 _subscribers_lock = asyncio.Lock()
+token_queue: asyncio.Queue[tuple[str, str, int, int]] = asyncio.Queue()

 # ------------------------------------------------------------------
-# aiohttp Global Sessions
+# Globals
 # ------------------------------------------------------------------
 app_state = {
     "session": None,
     "connector": None,
 }
+token_worker_task: asyncio.Task | None = None

 # -------------------------------------------------------------
 # 1. Configuration loader
@@ -125,6 +128,7 @@ default_headers={
 usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
 token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
 usage_lock = asyncio.Lock() # protects access to usage_counts
+token_usage_lock = asyncio.Lock()

 # -------------------------------------------------------------
 # 4.
Helperfunctions @@ -192,12 +196,12 @@ def is_ext_openai_endpoint(endpoint: str) -> bool: return True # It's an external OpenAI endpoint -def record_token_usage(endpoint: str, model: str, prompt: int = 0, completion: int = 0) -> None: - async def _record(): - async with usage_lock: # reuse the same lock that protects usage_counts - token_usage_counts[endpoint][model] += (prompt + completion) - await publish_snapshot() # immediately broadcast the new totals - asyncio.create_task(_record()) +async def token_worker() -> None: + while True: + endpoint, model, prompt, comp = await token_queue.get() + async with token_usage_lock: + token_usage_counts[endpoint][model] += (prompt + comp) + await publish_snapshot() class fetch: async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]: @@ -223,7 +227,7 @@ class fetch: del _models_cache[endpoint] if endpoint in _error_cache: - if _is_fresh(_error_cache[endpoint], 1): + if _is_fresh(_error_cache[endpoint], 10): # Still within the short error TTL – pretend nothing is available return set() else: @@ -266,6 +270,21 @@ class fetch: loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty set is returned. """ + if is_ext_openai_endpoint(endpoint): + return set() + if endpoint in _loaded_models_cache: + models, cached_at = _loaded_models_cache[endpoint] + if _is_fresh(cached_at, 30): + return models + else: + # stale entry – drop it + del _loaded_models_cache[endpoint] + + if endpoint in _error_cache: + if _is_fresh(_error_cache[endpoint], 10): + return set() + else: + del _error_cache[endpoint] client: aiohttp.ClientSession = app_state["session"] try: async with client.get(f"{endpoint}/api/ps") as resp: @@ -274,6 +293,7 @@ class fetch: # The response format is: # {"models": [{"name": "model1"}, {"name": "model2"}]} models = {m.get("name") for m in data.get("models", []) if m.get("name")} + _loaded_models_cache[endpoint] = (models, time.time()) return models except Exception as e: # If anything goes wrong we simply assume the endpoint has no models @@ -428,18 +448,19 @@ def transform_images_to_data_urls(message_list): class rechunk: def openai_chat_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.ChatResponse: + now = time.perf_counter() if chunk.choices == [] and chunk.usage is not None: return ollama.ChatResponse( model=chunk.model, created_at=iso8601_ns(), done=True, done_reason='stop', - total_duration=int((time.perf_counter() - start_ts) * 1_000_000_000), + total_duration=int((now - start_ts) * 1_000_000_000), load_duration=100000, prompt_eval_count=int(chunk.usage.prompt_tokens), - prompt_eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)), + prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)), eval_count=int(chunk.usage.completion_tokens), - eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000), + eval_duration=int((now - start_ts) * 1_000_000_000), message={"role": "assistant"} ) with_thinking = chunk.choices[0] if chunk.choices[0] else None @@ -463,16 +484,17 @@ class rechunk: created_at=iso8601_ns(), done=True if chunk.usage is not None else False, done_reason=chunk.choices[0].finish_reason, #if chunk.choices[0].finish_reason is not None else None, - total_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0, + total_duration=int((now - start_ts) * 1_000_000_000) if 
chunk.usage is not None else 0, load_duration=100000, prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0, - prompt_eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0, + prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0, eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0, - eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0, + eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0, message=assistant_msg) return rechunk def openai_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.GenerateResponse: + now = time.perf_counter() with_thinking = chunk.choices[0] if chunk.choices[0] else None thinking = getattr(with_thinking, "reasoning", None) if with_thinking else None rechunk = ollama.GenerateResponse( @@ -480,12 +502,12 @@ class rechunk: created_at=iso8601_ns(), done=True if chunk.usage is not None else False, done_reason=chunk.choices[0].finish_reason, - total_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0, + total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0, load_duration=10000, prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0, - prompt_eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0, + prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0, eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0, - eval_duration=int((time.perf_counter() - start_ts) * 1_000_000_000) if chunk.usage is not None else 0, + eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0, response=chunk.choices[0].text or '', thinking=thinking) return rechunk @@ -514,9 +536,9 @@ class rechunk: # ------------------------------------------------------------------ async def publish_snapshot(): async with usage_lock: - snapshot = json.dumps({"usage_counts": usage_counts, + snapshot = orjson.dumps({"usage_counts": usage_counts, "token_usage_counts": token_usage_counts, - }, sort_keys=True) + }, option=orjson.OPT_SORT_KEYS).decode("utf-8") async with _subscribers_lock: for q in _subscribers: # If the queue is full, drop the message to avoid back‑pressure. @@ -650,7 +672,7 @@ async def proxy(request: Request): """ try: body_bytes = await request.body() - payload = json.loads(body_bytes.decode("utf-8")) + payload = orjson.loads(body_bytes.decode("utf-8")) model = payload.get("model") prompt = payload.get("prompt") @@ -674,7 +696,7 @@ async def proxy(request: Request): raise HTTPException( status_code=400, detail="Missing required field 'prompt'" ) - except json.JSONDecodeError as e: + except orjson.JSONDecodeError as e: error_msg = f"Invalid JSON format in request body: {str(e)}. Please ensure the request is properly formatted." 
raise HTTPException(status_code=400, detail=error_msg) from e @@ -721,11 +743,11 @@ async def proxy(request: Request): chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts) prompt_tok = chunk.prompt_eval_count or 0 comp_tok = chunk.eval_count or 0 - record_token_usage(endpoint, model, prompt_tok, comp_tok) + await token_queue.put((endpoint, model, prompt_tok, comp_tok)) if hasattr(chunk, "model_dump_json"): json_line = chunk.model_dump_json() else: - json_line = json.dumps(chunk) + json_line = orjson.dumps(chunk) yield json_line.encode("utf-8") + b"\n" else: if is_openai_endpoint: @@ -735,11 +757,11 @@ async def proxy(request: Request): response = async_gen.model_dump_json() prompt_tok = async_gen.prompt_eval_count or 0 comp_tok = async_gen.eval_count or 0 - record_token_usage(endpoint, model, prompt_tok, comp_tok) + await token_queue.put((endpoint, model, prompt_tok, comp_tok)) json_line = ( response if hasattr(async_gen, "model_dump_json") - else json.dumps(async_gen) + else orjson.dumps(async_gen) ) yield json_line.encode("utf-8") + b"\n" @@ -764,7 +786,7 @@ async def chat_proxy(request: Request): # 1. Parse and validate request try: body_bytes = await request.body() - payload = json.loads(body_bytes.decode("utf-8")) + payload = orjson.loads(body_bytes.decode("utf-8")) model = payload.get("model") messages = payload.get("messages") @@ -787,7 +809,7 @@ async def chat_proxy(request: Request): raise HTTPException( status_code=400, detail="`options` must be a JSON object" ) - except json.JSONDecodeError as e: + except orjson.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Endpoint logic @@ -837,11 +859,11 @@ async def chat_proxy(request: Request): # `chunk` can be a dict or a pydantic model – dump to JSON safely prompt_tok = chunk.prompt_eval_count or 0 comp_tok = chunk.eval_count or 0 - record_token_usage(endpoint, model, prompt_tok, comp_tok) + await token_queue.put((endpoint, model, prompt_tok, comp_tok)) if hasattr(chunk, "model_dump_json"): json_line = chunk.model_dump_json() else: - json_line = json.dumps(chunk) + json_line = orjson.dumps(chunk) yield json_line.encode("utf-8") + b"\n" else: if is_openai_endpoint: @@ -851,11 +873,11 @@ async def chat_proxy(request: Request): response = async_gen.model_dump_json() prompt_tok = async_gen.prompt_eval_count or 0 comp_tok = async_gen.eval_count or 0 - record_token_usage(endpoint, model, prompt_tok, comp_tok) + await token_queue.put((endpoint, model, prompt_tok, comp_tok)) json_line = ( response if hasattr(async_gen, "model_dump_json") - else json.dumps(async_gen) + else orjson.dumps(async_gen) ) yield json_line.encode("utf-8") + b"\n" @@ -882,7 +904,7 @@ async def embedding_proxy(request: Request): # 1. Parse and validate request try: body_bytes = await request.body() - payload = json.loads(body_bytes.decode("utf-8")) + payload = orjson.loads(body_bytes.decode("utf-8")) model = payload.get("model") prompt = payload.get("prompt") @@ -897,7 +919,7 @@ async def embedding_proxy(request: Request): raise HTTPException( status_code=400, detail="Missing required field 'prompt'" ) - except json.JSONDecodeError as e: + except orjson.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. 
Endpoint logic @@ -923,7 +945,7 @@ async def embedding_proxy(request: Request): if hasattr(async_gen, "model_dump_json"): json_line = async_gen.model_dump_json() else: - json_line = json.dumps(async_gen) + json_line = orjson.dumps(async_gen) yield json_line.encode("utf-8") + b"\n" finally: # Ensure counter is decremented even if an exception occurs @@ -947,7 +969,7 @@ async def embed_proxy(request: Request): # 1. Parse and validate request try: body_bytes = await request.body() - payload = json.loads(body_bytes.decode("utf-8")) + payload = orjson.loads(body_bytes.decode("utf-8")) model = payload.get("model") _input = payload.get("input") @@ -963,12 +985,12 @@ async def embed_proxy(request: Request): raise HTTPException( status_code=400, detail="Missing required field 'input'" ) - except json.JSONDecodeError as e: + except orjson.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Endpoint logic endpoint = await choose_endpoint(model) - is_openai_endpoint = "/v1" in endpoint + is_openai_endpoint = is_ext_openai_endpoint(endpoint) #"/v1" in endpoint if is_openai_endpoint: if ":latest" in model: model = model.split(":latest") @@ -989,7 +1011,7 @@ async def embed_proxy(request: Request): if hasattr(async_gen, "model_dump_json"): json_line = async_gen.model_dump_json() else: - json_line = json.dumps(async_gen) + json_line = orjson.dumps(async_gen) yield json_line.encode("utf-8") + b"\n" finally: # Ensure counter is decremented even if an exception occurs @@ -1011,7 +1033,7 @@ async def create_proxy(request: Request): """ try: body_bytes = await request.body() - payload = json.loads(body_bytes.decode("utf-8")) + payload = orjson.loads(body_bytes.decode("utf-8")) model = payload.get("model") quantize = payload.get("quantize") @@ -1032,7 +1054,7 @@ async def create_proxy(request: Request): raise HTTPException( status_code=400, detail="You need to provide either from_ or files parameter!" ) - except json.JSONDecodeError as e: + except orjson.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e status_lists = [] @@ -1062,14 +1084,14 @@ async def show_proxy(request: Request, model: Optional[str] = None): body_bytes = await request.body() if not model: - payload = json.loads(body_bytes.decode("utf-8")) + payload = orjson.loads(body_bytes.decode("utf-8")) model = payload.get("model") if not model: raise HTTPException( status_code=400, detail="Missing required field 'model'" ) - except json.JSONDecodeError as e: + except orjson.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Endpoint logic @@ -1097,7 +1119,7 @@ async def copy_proxy(request: Request, source: Optional[str] = None, destination body_bytes = await request.body() if not source and not destination: - payload = json.loads(body_bytes.decode("utf-8")) + payload = orjson.loads(body_bytes.decode("utf-8")) src = payload.get("source") dst = payload.get("destination") else: @@ -1112,7 +1134,7 @@ async def copy_proxy(request: Request, source: Optional[str] = None, destination raise HTTPException( status_code=400, detail="Missing required field 'destination'" ) - except json.JSONDecodeError as e: + except orjson.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 3. 
Iterate over all endpoints to copy the model on each endpoint @@ -1141,14 +1163,14 @@ async def delete_proxy(request: Request, model: Optional[str] = None): body_bytes = await request.body() if not model: - payload = json.loads(body_bytes.decode("utf-8")) + payload = orjson.loads(body_bytes.decode("utf-8")) model = payload.get("model") if not model: raise HTTPException( status_code=400, detail="Missing required field 'model'" ) - except json.JSONDecodeError as e: + except orjson.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Iterate over all endpoints to delete the model on each endpoint @@ -1176,7 +1198,7 @@ async def pull_proxy(request: Request, model: Optional[str] = None): body_bytes = await request.body() if not model: - payload = json.loads(body_bytes.decode("utf-8")) + payload = orjson.loads(body_bytes.decode("utf-8")) model = payload.get("model") insecure = payload.get("insecure") else: @@ -1186,7 +1208,7 @@ async def pull_proxy(request: Request, model: Optional[str] = None): raise HTTPException( status_code=400, detail="Missing required field 'model'" ) - except json.JSONDecodeError as e: + except orjson.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Iterate over all endpoints to pull the model @@ -1218,7 +1240,7 @@ async def push_proxy(request: Request): # 1. Parse and validate request try: body_bytes = await request.body() - payload = json.loads(body_bytes.decode("utf-8")) + payload = orjson.loads(body_bytes.decode("utf-8")) model = payload.get("model") insecure = payload.get("insecure") @@ -1227,7 +1249,7 @@ async def push_proxy(request: Request): raise HTTPException( status_code=400, detail="Missing required field 'model'" ) - except json.JSONDecodeError as e: + except orjson.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Iterate over all endpoints @@ -1385,11 +1407,10 @@ async def openai_embedding_proxy(request: Request): # 1. Parse and validate request try: body_bytes = await request.body() - payload = json.loads(body_bytes.decode("utf-8")) + payload = orjson.loads(body_bytes.decode("utf-8")) model = payload.get("model") doc = payload.get("input") - if not model: raise HTTPException( @@ -1399,13 +1420,13 @@ async def openai_embedding_proxy(request: Request): raise HTTPException( status_code=400, detail="Missing required field 'input'" ) - except json.JSONDecodeError as e: + except orjson.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Endpoint logic endpoint = await choose_endpoint(model) await increment_usage(endpoint, model) - if "/v1" in endpoint: + if "/v1" in endpoint: # and is_ext_openai_endpoint(endpoint): api_key = config.api_keys[endpoint] else: api_key = "ollama" @@ -1432,7 +1453,7 @@ async def openai_chat_completions_proxy(request: Request): # 1. Parse and validate request try: body_bytes = await request.body() - payload = json.loads(body_bytes.decode("utf-8")) + payload = orjson.loads(body_bytes.decode("utf-8")) model = payload.get("model") messages = payload.get("messages") @@ -1483,7 +1504,7 @@ async def openai_chat_completions_proxy(request: Request): raise HTTPException( status_code=400, detail="Missing required field 'messages' (must be a list)" ) - except json.JSONDecodeError as e: + except orjson.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. 
Endpoint logic @@ -1501,7 +1522,7 @@ data = ( chunk.model_dump_json() if hasattr(chunk, "model_dump_json") - else json.dumps(chunk) + else orjson.dumps(chunk) ) if chunk.choices[0].delta.content is not None: yield f"data: {data}\n\n".encode("utf-8") @@ -1509,11 +1530,11 @@ else: prompt_tok = async_gen.usage.prompt_tokens or 0 comp_tok = async_gen.usage.completion_tokens or 0 - record_token_usage(endpoint, payload.get("model"), prompt_tok, comp_tok) + await token_queue.put((endpoint, model, prompt_tok, comp_tok)) json_line = ( async_gen.model_dump_json() if hasattr(async_gen, "model_dump_json") - else json.dumps(async_gen) + else orjson.dumps(async_gen) ) yield json_line.encode("utf-8") + b"\n" @@ -1539,7 +1560,7 @@ async def openai_completions_proxy(request: Request): # 1. Parse and validate request try: body_bytes = await request.body() - payload = json.loads(body_bytes.decode("utf-8")) + payload = orjson.loads(body_bytes.decode("utf-8")) model = payload.get("model") prompt = payload.get("prompt") @@ -1588,7 +1609,7 @@ raise HTTPException( status_code=400, detail="Missing required field 'prompt'" ) - except json.JSONDecodeError as e: + except orjson.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Endpoint logic @@ -1607,7 +1628,7 @@ data = ( chunk.model_dump_json() if hasattr(chunk, "model_dump_json") - else json.dumps(chunk) + else orjson.dumps(chunk) ) yield f"data: {data}\n\n".encode("utf-8") # Final DONE event @@ -1615,11 +1636,11 @@ else: prompt_tok = async_gen.usage.prompt_tokens or 0 comp_tok = async_gen.usage.completion_tokens or 0 - record_token_usage(endpoint, payload.get("model"), prompt_tok, comp_tok) + await token_queue.put((endpoint, model, prompt_tok, comp_tok)) json_line = ( async_gen.model_dump_json() if hasattr(async_gen, "model_dump_json") - else json.dumps(async_gen) + else orjson.dumps(async_gen) ) yield json_line.encode("utf-8") + b"\n" @@ -1694,7 +1715,7 @@ async def health_proxy(request: Request): * The HTTP status code is 200 when everything is healthy, 503 otherwise. """ # Run all health checks in parallel - tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints] + tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints] # if not is_ext_openai_endpoint(ep)] results = await asyncio.gather(*tasks, return_exceptions=True) @@ -1774,9 +1795,12 @@ async def startup_event() -> None: app_state["connector"] = connector app_state["session"] = session - + global token_worker_task + token_worker_task = asyncio.create_task(token_worker()) @app.on_event("shutdown") async def shutdown_event() -> None: await close_all_sse_queues() await app_state["session"].close() + if token_worker_task is not None: + token_worker_task.cancel() \ No newline at end of file
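
Note on the json -> orjson swap: unlike json.dumps, orjson.dumps returns bytes and takes no sort_keys/indent keyword arguments, so any fallback branch above that feeds its result into an f-string (the SSE "data: ..." frames) or calls .encode("utf-8") on it needs an explicit decode. A minimal sketch of a helper that keeps those call sites str-based; the name dump_json_line is illustrative and not part of the patch:

import orjson

def dump_json_line(obj) -> str:
    # orjson.dumps returns bytes; decode so callers can keep treating the
    # result like the str that json.dumps used to return
    return orjson.dumps(obj, option=orjson.OPT_SORT_KEYS).decode("utf-8")

# e.g. in the streaming fallbacks:
#   json_line = chunk.model_dump_json() if hasattr(chunk, "model_dump_json") else dump_json_line(chunk)
#   yield json_line.encode("utf-8") + b"\n"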
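
Token accounting now goes through token_queue and the token_worker task started at startup. A self-contained sketch of the lifecycle wiring, assuming FastAPI's on_event hooks as used in router.py; the worker below is a stand-in that only marks the shape of the loop, and awaiting the cancelled task at shutdown is an extra step not in the patch, so a half-applied usage update is not left behind:

import asyncio
import contextlib
from fastapi import FastAPI

app = FastAPI()
token_worker_task: asyncio.Task | None = None

async def token_worker() -> None:
    # stand-in for the real worker: drain the queue until cancelled
    while True:
        await asyncio.sleep(1)

@app.on_event("startup")
async def startup_event() -> None:
    # rebind the module-level handle; without `global` the assignment would
    # create a function-local name and shutdown would never see the task
    global token_worker_task
    token_worker_task = asyncio.create_task(token_worker())

@app.on_event("shutdown")
async def shutdown_event() -> None:
    if token_worker_task is not None:
        token_worker_task.cancel()
        # wait for the cancellation to land before the event loop goes away
        with contextlib.suppress(asyncio.CancelledError):
            await token_worker_task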