refactor: make choose_endpoint use cache incrementer for atomic updates

This commit is contained in:
Alpha Nerd 2026-03-03 14:57:37 +01:00
parent e7196146ad
commit e96e890511

140
router.py
View file

@ -928,11 +928,12 @@ async def decrement_usage(endpoint: str, model: str) -> None:
# usage_counts.pop(endpoint, None)
await publish_snapshot()
async def _make_chat_request(endpoint: str, model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
"""
Helper function to make a chat request to a specific endpoint.
Handles endpoint selection, client creation, usage tracking, and request execution.
"""
endpoint, tracking_model = await choose_endpoint(model) # selects and atomically reserves
use_openai = is_openai_compatible(endpoint)
if use_openai:
if ":latest" in model:
@ -962,10 +963,6 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
else:
client = ollama.AsyncClient(host=endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
try:
if use_openai:
start_ts = time.perf_counter()
@ -1057,18 +1054,11 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
moe_reqs = []
# Generate 3 responses
response1_endpoint = await choose_endpoint(model)
response1_task = asyncio.create_task(_make_chat_request(response1_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
response2_endpoint = await choose_endpoint(model)
response2_task = asyncio.create_task(_make_chat_request(response2_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
response3_endpoint = await choose_endpoint(model)
response3_task = asyncio.create_task(_make_chat_request(response3_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
# Generate 3 responses — choose_endpoint is called inside _make_chat_request and
# atomically reserves a slot, so all 3 tasks see each other's load immediately.
response1_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
response2_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
response3_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
responses = await asyncio.gather(response1_task, response2_task, response3_task)
@ -1077,17 +1067,9 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
moe_reqs.append(moe_req)
# Generate 3 critiques
critique1_endpoint = await choose_endpoint(model)
critique1_task = asyncio.create_task(_make_chat_request(critique1_endpoint, model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
critique2_endpoint = await choose_endpoint(model)
critique2_task = asyncio.create_task(_make_chat_request(critique2_endpoint, model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
critique3_endpoint = await choose_endpoint(model)
critique3_task = asyncio.create_task(_make_chat_request(critique3_endpoint, model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
await asyncio.sleep(0.01) # Small delay to allow usage count to update
critique1_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
critique2_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
critique3_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
critiques = await asyncio.gather(critique1_task, critique2_task, critique3_task)
@ -1095,8 +1077,7 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
m = enhance.moe_select_candidate(query, critiques)
# Generate final response
final_endpoint = await choose_endpoint(model)
return await _make_chat_request(final_endpoint, model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
return await _make_chat_request(model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
def iso8601_ns():
ns = time.time_ns()
@ -1465,7 +1446,7 @@ async def get_usage_counts() -> Dict:
# -------------------------------------------------------------
# 5. Endpoint selection logic (respecting the configurable limit)
# -------------------------------------------------------------
async def choose_endpoint(model: str) -> str:
async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]:
"""
Determine which endpoint to use for the given model while respecting
the `max_concurrent_connections` per endpointmodel pair **and**
@ -1526,7 +1507,8 @@ async def choose_endpoint(model: str) -> str:
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
loaded_sets = await asyncio.gather(*load_tasks)
# Protect all reads of usage_counts with the lock
# Protect all reads/writes of usage_counts with the lock so that selection
# and reservation are atomic — concurrent callers see each other's pending load.
async with usage_lock:
# Helper: current usage for (endpoint, model) using the same normalized key
# that increment_usage/decrement_usage store — raw model names differ from
@ -1544,34 +1526,37 @@ async def choose_endpoint(model: str) -> str:
# Sort ascending for load balancing — all endpoints here already have the
# model loaded, so there is no model-switching cost to optimise for.
loaded_and_free.sort(key=tracking_usage)
# When all candidates are equally idle, randomise to avoid always picking
# the first entry in a stable sort.
if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
return random.choice(loaded_and_free)
selected = random.choice(loaded_and_free)
else:
selected = loaded_and_free[0]
else:
# 4⃣ Endpoints among the candidates that simply have a free slot
endpoints_with_free_slot = [
ep for ep in candidate_endpoints
if tracking_usage(ep) < config.max_concurrent_connections
]
return loaded_and_free[0]
if endpoints_with_free_slot:
# Sort by total endpoint load (ascending) to prefer idle endpoints.
endpoints_with_free_slot.sort(
key=lambda ep: sum(usage_counts.get(ep, {}).values())
)
if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
selected = random.choice(endpoints_with_free_slot)
else:
selected = endpoints_with_free_slot[0]
else:
# 5⃣ All candidate endpoints are saturated pick the least-busy one (will queue)
selected = min(candidate_endpoints, key=tracking_usage)
# 4⃣ Endpoints among the candidates that simply have a free slot
endpoints_with_free_slot = [
ep for ep in candidate_endpoints
if tracking_usage(ep) < config.max_concurrent_connections
]
if endpoints_with_free_slot:
# Sort by total endpoint load (ascending) to prefer idle endpoints.
endpoints_with_free_slot.sort(
key=lambda ep: sum(usage_counts.get(ep, {}).values())
)
if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
return random.choice(endpoints_with_free_slot)
return endpoints_with_free_slot[0]
# 5⃣ All candidate endpoints are saturated pick the least-busy one (will queue)
ep = min(candidate_endpoints, key=tracking_usage)
return ep
tracking_model = get_tracking_model(selected, model)
if reserve:
usage_counts[selected][tracking_model] += 1
await publish_snapshot()
return selected, tracking_model
# -------------------------------------------------------------
# 6. API route Generate
@ -1612,10 +1597,8 @@ async def proxy(request: Request):
raise HTTPException(status_code=400, detail=error_msg) from e
endpoint = await choose_endpoint(model)
endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")
@ -1640,7 +1623,6 @@ async def proxy(request: Request):
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, tracking_model)
# 4. Async generator that streams data and decrements the counter
async def stream_generate_response():
@ -1735,10 +1717,8 @@ async def chat_proxy(request: Request):
opt = True
else:
opt = False
endpoint = await choose_endpoint(model)
endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")
@ -1769,7 +1749,6 @@ async def chat_proxy(request: Request):
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, tracking_model)
# 3. Async generator that streams chat data and decrements the counter
async def stream_chat_response():
try:
@ -1861,10 +1840,8 @@ async def embedding_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")
@ -1872,7 +1849,6 @@ async def embedding_proxy(request: Request):
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, tracking_model)
# 3. Async generator that streams embedding data and decrements the counter
async def stream_embedding_response():
try:
@ -1929,10 +1905,8 @@ async def embed_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")
@ -1940,7 +1914,6 @@ async def embed_proxy(request: Request):
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, tracking_model)
# 3. Async generator that streams embed data and decrements the counter
async def stream_embedding_response():
try:
@ -2038,8 +2011,7 @@ async def show_proxy(request: Request, model: Optional[str] = None):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
#await increment_usage(endpoint, model)
endpoint, _ = await choose_endpoint(model, reserve=False)
client = ollama.AsyncClient(host=endpoint)
@ -2628,10 +2600,7 @@ async def openai_embedding_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
endpoint, tracking_model = await choose_endpoint(model)
if is_openai_compatible(endpoint):
api_key = config.api_keys.get(endpoint, "no-key")
else:
@ -2722,10 +2691,7 @@ async def openai_chat_completions_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
endpoint, tracking_model = await choose_endpoint(model)
base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
# 3. Async generator that streams completions data and decrements the counter
@ -2868,10 +2834,7 @@ async def openai_completions_proxy(request: Request):
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
endpoint, tracking_model = await choose_endpoint(model)
base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
@ -3056,12 +3019,13 @@ async def rerank_proxy(request: Request):
# Determine which endpoint serves this model
try:
endpoint = await choose_endpoint(model)
endpoint, tracking_model = await choose_endpoint(model)
except RuntimeError as e:
raise HTTPException(status_code=404, detail=str(e))
# Ollama endpoints have no native rerank support
if not is_openai_compatible(endpoint):
await decrement_usage(endpoint, tracking_model)
raise HTTPException(
status_code=501,
detail=(
@ -3071,13 +3035,9 @@ async def rerank_proxy(request: Request):
),
)
# Normalize model name for tracking
tracking_model = get_tracking_model(endpoint, model)
if ":latest" in model:
model = model.split(":latest")[0]
await increment_usage(endpoint, tracking_model)
# Build upstream rerank request body forward only recognised fields
upstream_payload: dict = {"model": model, "query": query, "documents": documents}
for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):