diff --git a/router.py b/router.py index b817eac..f823b01 100644 --- a/router.py +++ b/router.py @@ -928,11 +928,12 @@ async def decrement_usage(endpoint: str, model: str) -> None: # usage_counts.pop(endpoint, None) await publish_snapshot() -async def _make_chat_request(endpoint: str, model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse: +async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse: """ Helper function to make a chat request to a specific endpoint. Handles endpoint selection, client creation, usage tracking, and request execution. """ + endpoint, tracking_model = await choose_endpoint(model) # selects and atomically reserves use_openai = is_openai_compatible(endpoint) if use_openai: if ":latest" in model: @@ -962,10 +963,6 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No else: client = ollama.AsyncClient(host=endpoint) - # Normalize model name for tracking so it matches the PS table key - tracking_model = get_tracking_model(endpoint, model) - await increment_usage(endpoint, tracking_model) - try: if use_openai: start_ts = time.perf_counter() @@ -1057,18 +1054,11 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool moe_reqs = [] - # Generate 3 responses - response1_endpoint = await choose_endpoint(model) - response1_task = asyncio.create_task(_make_chat_request(response1_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) - await asyncio.sleep(0.01) # Small delay to allow usage count to update - - response2_endpoint = await choose_endpoint(model) - response2_task = asyncio.create_task(_make_chat_request(response2_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) - await asyncio.sleep(0.01) # Small delay to allow usage count to update - - response3_endpoint = await choose_endpoint(model) - response3_task = asyncio.create_task(_make_chat_request(response3_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) - await asyncio.sleep(0.01) # Small delay to allow usage count to update + # Generate 3 responses — choose_endpoint is called inside _make_chat_request and + # atomically reserves a slot, so all 3 tasks see each other's load immediately. + response1_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) + response2_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) + response3_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) responses = await asyncio.gather(response1_task, response2_task, response3_task) @@ -1077,17 +1067,9 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool moe_reqs.append(moe_req) # Generate 3 critiques - critique1_endpoint = await choose_endpoint(model) - critique1_task = asyncio.create_task(_make_chat_request(critique1_endpoint, model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) - await asyncio.sleep(0.01) # Small delay to allow usage count to update - - critique2_endpoint = await choose_endpoint(model) - critique2_task = asyncio.create_task(_make_chat_request(critique2_endpoint, model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) - await asyncio.sleep(0.01) # Small delay to allow usage count to update - - critique3_endpoint = await choose_endpoint(model) - critique3_task = asyncio.create_task(_make_chat_request(critique3_endpoint, model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) - await asyncio.sleep(0.01) # Small delay to allow usage count to update + critique1_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) + critique2_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) + critique3_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) critiques = await asyncio.gather(critique1_task, critique2_task, critique3_task) @@ -1095,8 +1077,7 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool m = enhance.moe_select_candidate(query, critiques) # Generate final response - final_endpoint = await choose_endpoint(model) - return await _make_chat_request(final_endpoint, model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive) + return await _make_chat_request(model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive) def iso8601_ns(): ns = time.time_ns() @@ -1465,7 +1446,7 @@ async def get_usage_counts() -> Dict: # ------------------------------------------------------------- # 5. Endpoint selection logic (respecting the configurable limit) # ------------------------------------------------------------- -async def choose_endpoint(model: str) -> str: +async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]: """ Determine which endpoint to use for the given model while respecting the `max_concurrent_connections` per endpoint‑model pair **and** @@ -1526,7 +1507,8 @@ async def choose_endpoint(model: str) -> str: load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints] loaded_sets = await asyncio.gather(*load_tasks) - # Protect all reads of usage_counts with the lock + # Protect all reads/writes of usage_counts with the lock so that selection + # and reservation are atomic — concurrent callers see each other's pending load. async with usage_lock: # Helper: current usage for (endpoint, model) using the same normalized key # that increment_usage/decrement_usage store — raw model names differ from @@ -1544,34 +1526,37 @@ async def choose_endpoint(model: str) -> str: # Sort ascending for load balancing — all endpoints here already have the # model loaded, so there is no model-switching cost to optimise for. loaded_and_free.sort(key=tracking_usage) - # When all candidates are equally idle, randomise to avoid always picking # the first entry in a stable sort. if all(tracking_usage(ep) == 0 for ep in loaded_and_free): - return random.choice(loaded_and_free) + selected = random.choice(loaded_and_free) + else: + selected = loaded_and_free[0] + else: + # 4️⃣ Endpoints among the candidates that simply have a free slot + endpoints_with_free_slot = [ + ep for ep in candidate_endpoints + if tracking_usage(ep) < config.max_concurrent_connections + ] - return loaded_and_free[0] + if endpoints_with_free_slot: + # Sort by total endpoint load (ascending) to prefer idle endpoints. + endpoints_with_free_slot.sort( + key=lambda ep: sum(usage_counts.get(ep, {}).values()) + ) + if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot): + selected = random.choice(endpoints_with_free_slot) + else: + selected = endpoints_with_free_slot[0] + else: + # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue) + selected = min(candidate_endpoints, key=tracking_usage) - # 4️⃣ Endpoints among the candidates that simply have a free slot - endpoints_with_free_slot = [ - ep for ep in candidate_endpoints - if tracking_usage(ep) < config.max_concurrent_connections - ] - - if endpoints_with_free_slot: - # Sort by total endpoint load (ascending) to prefer idle endpoints. - endpoints_with_free_slot.sort( - key=lambda ep: sum(usage_counts.get(ep, {}).values()) - ) - - if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot): - return random.choice(endpoints_with_free_slot) - - return endpoints_with_free_slot[0] - - # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue) - ep = min(candidate_endpoints, key=tracking_usage) - return ep + tracking_model = get_tracking_model(selected, model) + if reserve: + usage_counts[selected][tracking_model] += 1 + await publish_snapshot() + return selected, tracking_model # ------------------------------------------------------------- # 6. API route – Generate @@ -1612,10 +1597,8 @@ async def proxy(request: Request): raise HTTPException(status_code=400, detail=error_msg) from e - endpoint = await choose_endpoint(model) + endpoint, tracking_model = await choose_endpoint(model) use_openai = is_openai_compatible(endpoint) - # Normalize model name for tracking so it matches the PS table key - tracking_model = get_tracking_model(endpoint, model) if use_openai: if ":latest" in model: model = model.split(":latest") @@ -1640,7 +1623,6 @@ async def proxy(request: Request): oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) - await increment_usage(endpoint, tracking_model) # 4. Async generator that streams data and decrements the counter async def stream_generate_response(): @@ -1735,10 +1717,8 @@ async def chat_proxy(request: Request): opt = True else: opt = False - endpoint = await choose_endpoint(model) + endpoint, tracking_model = await choose_endpoint(model) use_openai = is_openai_compatible(endpoint) - # Normalize model name for tracking so it matches the PS table key - tracking_model = get_tracking_model(endpoint, model) if use_openai: if ":latest" in model: model = model.split(":latest") @@ -1769,7 +1749,6 @@ async def chat_proxy(request: Request): oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) - await increment_usage(endpoint, tracking_model) # 3. Async generator that streams chat data and decrements the counter async def stream_chat_response(): try: @@ -1861,10 +1840,8 @@ async def embedding_proxy(request: Request): raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Endpoint logic - endpoint = await choose_endpoint(model) + endpoint, tracking_model = await choose_endpoint(model) use_openai = is_openai_compatible(endpoint) - # Normalize model name for tracking so it matches the PS table key - tracking_model = get_tracking_model(endpoint, model) if use_openai: if ":latest" in model: model = model.split(":latest") @@ -1872,7 +1849,6 @@ async def embedding_proxy(request: Request): client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) - await increment_usage(endpoint, tracking_model) # 3. Async generator that streams embedding data and decrements the counter async def stream_embedding_response(): try: @@ -1929,10 +1905,8 @@ async def embed_proxy(request: Request): raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Endpoint logic - endpoint = await choose_endpoint(model) + endpoint, tracking_model = await choose_endpoint(model) use_openai = is_openai_compatible(endpoint) - # Normalize model name for tracking so it matches the PS table key - tracking_model = get_tracking_model(endpoint, model) if use_openai: if ":latest" in model: model = model.split(":latest") @@ -1940,7 +1914,6 @@ async def embed_proxy(request: Request): client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key")) else: client = ollama.AsyncClient(host=endpoint) - await increment_usage(endpoint, tracking_model) # 3. Async generator that streams embed data and decrements the counter async def stream_embedding_response(): try: @@ -2038,8 +2011,7 @@ async def show_proxy(request: Request, model: Optional[str] = None): raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Endpoint logic - endpoint = await choose_endpoint(model) - #await increment_usage(endpoint, model) + endpoint, _ = await choose_endpoint(model, reserve=False) client = ollama.AsyncClient(host=endpoint) @@ -2628,10 +2600,7 @@ async def openai_embedding_proxy(request: Request): raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Endpoint logic - endpoint = await choose_endpoint(model) - # Normalize model name for tracking so it matches the PS table key - tracking_model = get_tracking_model(endpoint, model) - await increment_usage(endpoint, tracking_model) + endpoint, tracking_model = await choose_endpoint(model) if is_openai_compatible(endpoint): api_key = config.api_keys.get(endpoint, "no-key") else: @@ -2722,10 +2691,7 @@ async def openai_chat_completions_proxy(request: Request): raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Endpoint logic - endpoint = await choose_endpoint(model) - # Normalize model name for tracking so it matches the PS table key - tracking_model = get_tracking_model(endpoint, model) - await increment_usage(endpoint, tracking_model) + endpoint, tracking_model = await choose_endpoint(model) base_url = ep2base(endpoint) oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) # 3. Async generator that streams completions data and decrements the counter @@ -2868,10 +2834,7 @@ async def openai_completions_proxy(request: Request): raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e # 2. Endpoint logic - endpoint = await choose_endpoint(model) - # Normalize model name for tracking so it matches the PS table key - tracking_model = get_tracking_model(endpoint, model) - await increment_usage(endpoint, tracking_model) + endpoint, tracking_model = await choose_endpoint(model) base_url = ep2base(endpoint) oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) @@ -3056,12 +3019,13 @@ async def rerank_proxy(request: Request): # Determine which endpoint serves this model try: - endpoint = await choose_endpoint(model) + endpoint, tracking_model = await choose_endpoint(model) except RuntimeError as e: raise HTTPException(status_code=404, detail=str(e)) # Ollama endpoints have no native rerank support if not is_openai_compatible(endpoint): + await decrement_usage(endpoint, tracking_model) raise HTTPException( status_code=501, detail=( @@ -3071,13 +3035,9 @@ async def rerank_proxy(request: Request): ), ) - # Normalize model name for tracking - tracking_model = get_tracking_model(endpoint, model) if ":latest" in model: model = model.split(":latest")[0] - await increment_usage(endpoint, tracking_model) - # Build upstream rerank request body – forward only recognised fields upstream_payload: dict = {"model": model, "query": query, "documents": documents} for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):