refactor: make choose_endpoint use cache incrementer for atomic updates

This commit is contained in:
Alpha Nerd 2026-03-03 14:57:37 +01:00
parent e7196146ad
commit e96e890511

118
router.py
View file

@ -928,11 +928,12 @@ async def decrement_usage(endpoint: str, model: str) -> None:
# usage_counts.pop(endpoint, None) # usage_counts.pop(endpoint, None)
await publish_snapshot() await publish_snapshot()
async def _make_chat_request(endpoint: str, model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse: async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
""" """
Helper function to make a chat request to a specific endpoint. Helper function to make a chat request to a specific endpoint.
Handles endpoint selection, client creation, usage tracking, and request execution. Handles endpoint selection, client creation, usage tracking, and request execution.
""" """
endpoint, tracking_model = await choose_endpoint(model) # selects and atomically reserves
use_openai = is_openai_compatible(endpoint) use_openai = is_openai_compatible(endpoint)
if use_openai: if use_openai:
if ":latest" in model: if ":latest" in model:
@ -962,10 +963,6 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
else: else:
client = ollama.AsyncClient(host=endpoint) client = ollama.AsyncClient(host=endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
try: try:
if use_openai: if use_openai:
start_ts = time.perf_counter() start_ts = time.perf_counter()
@ -1057,18 +1054,11 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
moe_reqs = [] moe_reqs = []
# Generate 3 responses # Generate 3 responses — choose_endpoint is called inside _make_chat_request and
response1_endpoint = await choose_endpoint(model) # atomically reserves a slot, so all 3 tasks see each other's load immediately.
response1_task = asyncio.create_task(_make_chat_request(response1_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) response1_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update response2_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
response3_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
response2_endpoint = await choose_endpoint(model)
response2_task = asyncio.create_task(_make_chat_request(response2_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
response3_endpoint = await choose_endpoint(model)
response3_task = asyncio.create_task(_make_chat_request(response3_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
responses = await asyncio.gather(response1_task, response2_task, response3_task) responses = await asyncio.gather(response1_task, response2_task, response3_task)
@ -1077,17 +1067,9 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
moe_reqs.append(moe_req) moe_reqs.append(moe_req)
# Generate 3 critiques # Generate 3 critiques
critique1_endpoint = await choose_endpoint(model) critique1_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
critique1_task = asyncio.create_task(_make_chat_request(critique1_endpoint, model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)) critique2_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update critique3_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
critique2_endpoint = await choose_endpoint(model)
critique2_task = asyncio.create_task(_make_chat_request(critique2_endpoint, model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
critique3_endpoint = await choose_endpoint(model)
critique3_task = asyncio.create_task(_make_chat_request(critique3_endpoint, model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
critiques = await asyncio.gather(critique1_task, critique2_task, critique3_task) critiques = await asyncio.gather(critique1_task, critique2_task, critique3_task)
@ -1095,8 +1077,7 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
m = enhance.moe_select_candidate(query, critiques) m = enhance.moe_select_candidate(query, critiques)
# Generate final response # Generate final response
final_endpoint = await choose_endpoint(model) return await _make_chat_request(model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
return await _make_chat_request(final_endpoint, model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
def iso8601_ns(): def iso8601_ns():
ns = time.time_ns() ns = time.time_ns()
@ -1465,7 +1446,7 @@ async def get_usage_counts() -> Dict:
# ------------------------------------------------------------- # -------------------------------------------------------------
# 5. Endpoint selection logic (respecting the configurable limit) # 5. Endpoint selection logic (respecting the configurable limit)
# ------------------------------------------------------------- # -------------------------------------------------------------
async def choose_endpoint(model: str) -> str: async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]:
""" """
Determine which endpoint to use for the given model while respecting Determine which endpoint to use for the given model while respecting
the `max_concurrent_connections` per endpoint–model pair **and** the `max_concurrent_connections` per endpoint–model pair **and**
@ -1526,7 +1507,8 @@ async def choose_endpoint(model: str) -> str:
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints] load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
loaded_sets = await asyncio.gather(*load_tasks) loaded_sets = await asyncio.gather(*load_tasks)
# Protect all reads of usage_counts with the lock # Protect all reads/writes of usage_counts with the lock so that selection
# and reservation are atomic — concurrent callers see each other's pending load.
async with usage_lock: async with usage_lock:
# Helper: current usage for (endpoint, model) using the same normalized key # Helper: current usage for (endpoint, model) using the same normalized key
# that increment_usage/decrement_usage store — raw model names differ from # that increment_usage/decrement_usage store — raw model names differ from
@ -1544,14 +1526,13 @@ async def choose_endpoint(model: str) -> str:
# Sort ascending for load balancing — all endpoints here already have the # Sort ascending for load balancing — all endpoints here already have the
# model loaded, so there is no model-switching cost to optimise for. # model loaded, so there is no model-switching cost to optimise for.
loaded_and_free.sort(key=tracking_usage) loaded_and_free.sort(key=tracking_usage)
# When all candidates are equally idle, randomise to avoid always picking # When all candidates are equally idle, randomise to avoid always picking
# the first entry in a stable sort. # the first entry in a stable sort.
if all(tracking_usage(ep) == 0 for ep in loaded_and_free): if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
return random.choice(loaded_and_free) selected = random.choice(loaded_and_free)
else:
return loaded_and_free[0] selected = loaded_and_free[0]
else:
# 4⃣ Endpoints among the candidates that simply have a free slot # 4⃣ Endpoints among the candidates that simply have a free slot
endpoints_with_free_slot = [ endpoints_with_free_slot = [
ep for ep in candidate_endpoints ep for ep in candidate_endpoints
@ -1563,15 +1544,19 @@ async def choose_endpoint(model: str) -> str:
endpoints_with_free_slot.sort( endpoints_with_free_slot.sort(
key=lambda ep: sum(usage_counts.get(ep, {}).values()) key=lambda ep: sum(usage_counts.get(ep, {}).values())
) )
if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot): if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
return random.choice(endpoints_with_free_slot) selected = random.choice(endpoints_with_free_slot)
else:
return endpoints_with_free_slot[0] selected = endpoints_with_free_slot[0]
else:
# 5⃣ All candidate endpoints are saturated — pick the least-busy one (will queue) # 5⃣ All candidate endpoints are saturated — pick the least-busy one (will queue)
ep = min(candidate_endpoints, key=tracking_usage) selected = min(candidate_endpoints, key=tracking_usage)
return ep
tracking_model = get_tracking_model(selected, model)
if reserve:
usage_counts[selected][tracking_model] += 1
await publish_snapshot()
return selected, tracking_model
# ------------------------------------------------------------- # -------------------------------------------------------------
# 6. API route Generate # 6. API route Generate
@ -1612,10 +1597,8 @@ async def proxy(request: Request):
raise HTTPException(status_code=400, detail=error_msg) from e raise HTTPException(status_code=400, detail=error_msg) from e
endpoint = await choose_endpoint(model) endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint) use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai: if use_openai:
if ":latest" in model: if ":latest" in model:
model = model.split(":latest") model = model.split(":latest")
@ -1640,7 +1623,6 @@ async def proxy(request: Request):
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
else: else:
client = ollama.AsyncClient(host=endpoint) client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, tracking_model)
# 4. Async generator that streams data and decrements the counter # 4. Async generator that streams data and decrements the counter
async def stream_generate_response(): async def stream_generate_response():
@ -1735,10 +1717,8 @@ async def chat_proxy(request: Request):
opt = True opt = True
else: else:
opt = False opt = False
endpoint = await choose_endpoint(model) endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint) use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai: if use_openai:
if ":latest" in model: if ":latest" in model:
model = model.split(":latest") model = model.split(":latest")
@ -1769,7 +1749,6 @@ async def chat_proxy(request: Request):
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
else: else:
client = ollama.AsyncClient(host=endpoint) client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, tracking_model)
# 3. Async generator that streams chat data and decrements the counter # 3. Async generator that streams chat data and decrements the counter
async def stream_chat_response(): async def stream_chat_response():
try: try:
@ -1861,10 +1840,8 @@ async def embedding_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic # 2. Endpoint logic
endpoint = await choose_endpoint(model) endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint) use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai: if use_openai:
if ":latest" in model: if ":latest" in model:
model = model.split(":latest") model = model.split(":latest")
@ -1872,7 +1849,6 @@ async def embedding_proxy(request: Request):
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key")) client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
else: else:
client = ollama.AsyncClient(host=endpoint) client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, tracking_model)
# 3. Async generator that streams embedding data and decrements the counter # 3. Async generator that streams embedding data and decrements the counter
async def stream_embedding_response(): async def stream_embedding_response():
try: try:
@ -1929,10 +1905,8 @@ async def embed_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic # 2. Endpoint logic
endpoint = await choose_endpoint(model) endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint) use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai: if use_openai:
if ":latest" in model: if ":latest" in model:
model = model.split(":latest") model = model.split(":latest")
@ -1940,7 +1914,6 @@ async def embed_proxy(request: Request):
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key")) client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
else: else:
client = ollama.AsyncClient(host=endpoint) client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, tracking_model)
# 3. Async generator that streams embed data and decrements the counter # 3. Async generator that streams embed data and decrements the counter
async def stream_embedding_response(): async def stream_embedding_response():
try: try:
@ -2038,8 +2011,7 @@ async def show_proxy(request: Request, model: Optional[str] = None):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic # 2. Endpoint logic
endpoint = await choose_endpoint(model) endpoint, _ = await choose_endpoint(model, reserve=False)
#await increment_usage(endpoint, model)
client = ollama.AsyncClient(host=endpoint) client = ollama.AsyncClient(host=endpoint)
@ -2628,10 +2600,7 @@ async def openai_embedding_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic # 2. Endpoint logic
endpoint = await choose_endpoint(model) endpoint, tracking_model = await choose_endpoint(model)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
if is_openai_compatible(endpoint): if is_openai_compatible(endpoint):
api_key = config.api_keys.get(endpoint, "no-key") api_key = config.api_keys.get(endpoint, "no-key")
else: else:
@ -2722,10 +2691,7 @@ async def openai_chat_completions_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic # 2. Endpoint logic
endpoint = await choose_endpoint(model) endpoint, tracking_model = await choose_endpoint(model)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
base_url = ep2base(endpoint) base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
# 3. Async generator that streams completions data and decrements the counter # 3. Async generator that streams completions data and decrements the counter
@ -2868,10 +2834,7 @@ async def openai_completions_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic # 2. Endpoint logic
endpoint = await choose_endpoint(model) endpoint, tracking_model = await choose_endpoint(model)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
base_url = ep2base(endpoint) base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
@ -3056,12 +3019,13 @@ async def rerank_proxy(request: Request):
# Determine which endpoint serves this model # Determine which endpoint serves this model
try: try:
endpoint = await choose_endpoint(model) endpoint, tracking_model = await choose_endpoint(model)
except RuntimeError as e: except RuntimeError as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
# Ollama endpoints have no native rerank support # Ollama endpoints have no native rerank support
if not is_openai_compatible(endpoint): if not is_openai_compatible(endpoint):
await decrement_usage(endpoint, tracking_model)
raise HTTPException( raise HTTPException(
status_code=501, status_code=501,
detail=( detail=(
@ -3071,13 +3035,9 @@ async def rerank_proxy(request: Request):
), ),
) )
# Normalize model name for tracking
tracking_model = get_tracking_model(endpoint, model)
if ":latest" in model: if ":latest" in model:
model = model.split(":latest")[0] model = model.split(":latest")[0]
await increment_usage(endpoint, tracking_model)
# Build upstream rerank request body — forward only recognised fields # Build upstream rerank request body — forward only recognised fields
upstream_payload: dict = {"model": model, "query": query, "documents": documents} upstream_payload: dict = {"model": model, "query": query, "documents": documents}
for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"): for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):