refactor: make choose_endpoint use cache incrementer for atomic updates
This commit is contained in:
parent
e7196146ad
commit
e96e890511
1 changed files with 50 additions and 90 deletions
140
router.py
140
router.py
|
|
@ -928,11 +928,12 @@ async def decrement_usage(endpoint: str, model: str) -> None:
|
|||
# usage_counts.pop(endpoint, None)
|
||||
await publish_snapshot()
|
||||
|
||||
async def _make_chat_request(endpoint: str, model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
|
||||
async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
|
||||
"""
|
||||
Helper function to make a chat request to a specific endpoint.
|
||||
Handles endpoint selection, client creation, usage tracking, and request execution.
|
||||
"""
|
||||
endpoint, tracking_model = await choose_endpoint(model) # selects and atomically reserves
|
||||
use_openai = is_openai_compatible(endpoint)
|
||||
if use_openai:
|
||||
if ":latest" in model:
|
||||
|
|
@ -962,10 +963,6 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
|
|||
else:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
|
||||
# Normalize model name for tracking so it matches the PS table key
|
||||
tracking_model = get_tracking_model(endpoint, model)
|
||||
await increment_usage(endpoint, tracking_model)
|
||||
|
||||
try:
|
||||
if use_openai:
|
||||
start_ts = time.perf_counter()
|
||||
|
|
@ -1057,18 +1054,11 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
|
|||
|
||||
moe_reqs = []
|
||||
|
||||
# Generate 3 responses
|
||||
response1_endpoint = await choose_endpoint(model)
|
||||
response1_task = asyncio.create_task(_make_chat_request(response1_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
await asyncio.sleep(0.01) # Small delay to allow usage count to update
|
||||
|
||||
response2_endpoint = await choose_endpoint(model)
|
||||
response2_task = asyncio.create_task(_make_chat_request(response2_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
await asyncio.sleep(0.01) # Small delay to allow usage count to update
|
||||
|
||||
response3_endpoint = await choose_endpoint(model)
|
||||
response3_task = asyncio.create_task(_make_chat_request(response3_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
await asyncio.sleep(0.01) # Small delay to allow usage count to update
|
||||
# Generate 3 responses — choose_endpoint is called inside _make_chat_request and
|
||||
# atomically reserves a slot, so all 3 tasks see each other's load immediately.
|
||||
response1_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
response2_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
response3_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
|
||||
responses = await asyncio.gather(response1_task, response2_task, response3_task)
|
||||
|
||||
|
|
@ -1077,17 +1067,9 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
|
|||
moe_reqs.append(moe_req)
|
||||
|
||||
# Generate 3 critiques
|
||||
critique1_endpoint = await choose_endpoint(model)
|
||||
critique1_task = asyncio.create_task(_make_chat_request(critique1_endpoint, model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
await asyncio.sleep(0.01) # Small delay to allow usage count to update
|
||||
|
||||
critique2_endpoint = await choose_endpoint(model)
|
||||
critique2_task = asyncio.create_task(_make_chat_request(critique2_endpoint, model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
await asyncio.sleep(0.01) # Small delay to allow usage count to update
|
||||
|
||||
critique3_endpoint = await choose_endpoint(model)
|
||||
critique3_task = asyncio.create_task(_make_chat_request(critique3_endpoint, model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
await asyncio.sleep(0.01) # Small delay to allow usage count to update
|
||||
critique1_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
critique2_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
critique3_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||
|
||||
critiques = await asyncio.gather(critique1_task, critique2_task, critique3_task)
|
||||
|
||||
|
|
@ -1095,8 +1077,7 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
|
|||
m = enhance.moe_select_candidate(query, critiques)
|
||||
|
||||
# Generate final response
|
||||
final_endpoint = await choose_endpoint(model)
|
||||
return await _make_chat_request(final_endpoint, model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
|
||||
return await _make_chat_request(model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
|
||||
|
||||
def iso8601_ns():
|
||||
ns = time.time_ns()
|
||||
|
|
@ -1465,7 +1446,7 @@ async def get_usage_counts() -> Dict:
|
|||
# -------------------------------------------------------------
|
||||
# 5. Endpoint selection logic (respecting the configurable limit)
|
||||
# -------------------------------------------------------------
|
||||
async def choose_endpoint(model: str) -> str:
|
||||
async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]:
|
||||
"""
|
||||
Determine which endpoint to use for the given model while respecting
|
||||
the `max_concurrent_connections` per endpoint‑model pair **and**
|
||||
|
|
@ -1526,7 +1507,8 @@ async def choose_endpoint(model: str) -> str:
|
|||
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
|
||||
loaded_sets = await asyncio.gather(*load_tasks)
|
||||
|
||||
# Protect all reads of usage_counts with the lock
|
||||
# Protect all reads/writes of usage_counts with the lock so that selection
|
||||
# and reservation are atomic — concurrent callers see each other's pending load.
|
||||
async with usage_lock:
|
||||
# Helper: current usage for (endpoint, model) using the same normalized key
|
||||
# that increment_usage/decrement_usage store — raw model names differ from
|
||||
|
|
@ -1544,34 +1526,37 @@ async def choose_endpoint(model: str) -> str:
|
|||
# Sort ascending for load balancing — all endpoints here already have the
|
||||
# model loaded, so there is no model-switching cost to optimise for.
|
||||
loaded_and_free.sort(key=tracking_usage)
|
||||
|
||||
# When all candidates are equally idle, randomise to avoid always picking
|
||||
# the first entry in a stable sort.
|
||||
if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
|
||||
return random.choice(loaded_and_free)
|
||||
selected = random.choice(loaded_and_free)
|
||||
else:
|
||||
selected = loaded_and_free[0]
|
||||
else:
|
||||
# 4️⃣ Endpoints among the candidates that simply have a free slot
|
||||
endpoints_with_free_slot = [
|
||||
ep for ep in candidate_endpoints
|
||||
if tracking_usage(ep) < config.max_concurrent_connections
|
||||
]
|
||||
|
||||
return loaded_and_free[0]
|
||||
if endpoints_with_free_slot:
|
||||
# Sort by total endpoint load (ascending) to prefer idle endpoints.
|
||||
endpoints_with_free_slot.sort(
|
||||
key=lambda ep: sum(usage_counts.get(ep, {}).values())
|
||||
)
|
||||
if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
|
||||
selected = random.choice(endpoints_with_free_slot)
|
||||
else:
|
||||
selected = endpoints_with_free_slot[0]
|
||||
else:
|
||||
# 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue)
|
||||
selected = min(candidate_endpoints, key=tracking_usage)
|
||||
|
||||
# 4️⃣ Endpoints among the candidates that simply have a free slot
|
||||
endpoints_with_free_slot = [
|
||||
ep for ep in candidate_endpoints
|
||||
if tracking_usage(ep) < config.max_concurrent_connections
|
||||
]
|
||||
|
||||
if endpoints_with_free_slot:
|
||||
# Sort by total endpoint load (ascending) to prefer idle endpoints.
|
||||
endpoints_with_free_slot.sort(
|
||||
key=lambda ep: sum(usage_counts.get(ep, {}).values())
|
||||
)
|
||||
|
||||
if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
|
||||
return random.choice(endpoints_with_free_slot)
|
||||
|
||||
return endpoints_with_free_slot[0]
|
||||
|
||||
# 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue)
|
||||
ep = min(candidate_endpoints, key=tracking_usage)
|
||||
return ep
|
||||
tracking_model = get_tracking_model(selected, model)
|
||||
if reserve:
|
||||
usage_counts[selected][tracking_model] += 1
|
||||
await publish_snapshot()
|
||||
return selected, tracking_model
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 6. API route – Generate
|
||||
|
|
@ -1612,10 +1597,8 @@ async def proxy(request: Request):
|
|||
raise HTTPException(status_code=400, detail=error_msg) from e
|
||||
|
||||
|
||||
endpoint = await choose_endpoint(model)
|
||||
endpoint, tracking_model = await choose_endpoint(model)
|
||||
use_openai = is_openai_compatible(endpoint)
|
||||
# Normalize model name for tracking so it matches the PS table key
|
||||
tracking_model = get_tracking_model(endpoint, model)
|
||||
if use_openai:
|
||||
if ":latest" in model:
|
||||
model = model.split(":latest")
|
||||
|
|
@ -1640,7 +1623,6 @@ async def proxy(request: Request):
|
|||
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||
else:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
await increment_usage(endpoint, tracking_model)
|
||||
|
||||
# 4. Async generator that streams data and decrements the counter
|
||||
async def stream_generate_response():
|
||||
|
|
@ -1735,10 +1717,8 @@ async def chat_proxy(request: Request):
|
|||
opt = True
|
||||
else:
|
||||
opt = False
|
||||
endpoint = await choose_endpoint(model)
|
||||
endpoint, tracking_model = await choose_endpoint(model)
|
||||
use_openai = is_openai_compatible(endpoint)
|
||||
# Normalize model name for tracking so it matches the PS table key
|
||||
tracking_model = get_tracking_model(endpoint, model)
|
||||
if use_openai:
|
||||
if ":latest" in model:
|
||||
model = model.split(":latest")
|
||||
|
|
@ -1769,7 +1749,6 @@ async def chat_proxy(request: Request):
|
|||
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||
else:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
await increment_usage(endpoint, tracking_model)
|
||||
# 3. Async generator that streams chat data and decrements the counter
|
||||
async def stream_chat_response():
|
||||
try:
|
||||
|
|
@ -1861,10 +1840,8 @@ async def embedding_proxy(request: Request):
|
|||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
endpoint, tracking_model = await choose_endpoint(model)
|
||||
use_openai = is_openai_compatible(endpoint)
|
||||
# Normalize model name for tracking so it matches the PS table key
|
||||
tracking_model = get_tracking_model(endpoint, model)
|
||||
if use_openai:
|
||||
if ":latest" in model:
|
||||
model = model.split(":latest")
|
||||
|
|
@ -1872,7 +1849,6 @@ async def embedding_proxy(request: Request):
|
|||
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
|
||||
else:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
await increment_usage(endpoint, tracking_model)
|
||||
# 3. Async generator that streams embedding data and decrements the counter
|
||||
async def stream_embedding_response():
|
||||
try:
|
||||
|
|
@ -1929,10 +1905,8 @@ async def embed_proxy(request: Request):
|
|||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
endpoint, tracking_model = await choose_endpoint(model)
|
||||
use_openai = is_openai_compatible(endpoint)
|
||||
# Normalize model name for tracking so it matches the PS table key
|
||||
tracking_model = get_tracking_model(endpoint, model)
|
||||
if use_openai:
|
||||
if ":latest" in model:
|
||||
model = model.split(":latest")
|
||||
|
|
@ -1940,7 +1914,6 @@ async def embed_proxy(request: Request):
|
|||
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
|
||||
else:
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
await increment_usage(endpoint, tracking_model)
|
||||
# 3. Async generator that streams embed data and decrements the counter
|
||||
async def stream_embedding_response():
|
||||
try:
|
||||
|
|
@ -2038,8 +2011,7 @@ async def show_proxy(request: Request, model: Optional[str] = None):
|
|||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
#await increment_usage(endpoint, model)
|
||||
endpoint, _ = await choose_endpoint(model, reserve=False)
|
||||
|
||||
client = ollama.AsyncClient(host=endpoint)
|
||||
|
||||
|
|
@ -2628,10 +2600,7 @@ async def openai_embedding_proxy(request: Request):
|
|||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
# Normalize model name for tracking so it matches the PS table key
|
||||
tracking_model = get_tracking_model(endpoint, model)
|
||||
await increment_usage(endpoint, tracking_model)
|
||||
endpoint, tracking_model = await choose_endpoint(model)
|
||||
if is_openai_compatible(endpoint):
|
||||
api_key = config.api_keys.get(endpoint, "no-key")
|
||||
else:
|
||||
|
|
@ -2722,10 +2691,7 @@ async def openai_chat_completions_proxy(request: Request):
|
|||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
# Normalize model name for tracking so it matches the PS table key
|
||||
tracking_model = get_tracking_model(endpoint, model)
|
||||
await increment_usage(endpoint, tracking_model)
|
||||
endpoint, tracking_model = await choose_endpoint(model)
|
||||
base_url = ep2base(endpoint)
|
||||
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||
# 3. Async generator that streams completions data and decrements the counter
|
||||
|
|
@ -2868,10 +2834,7 @@ async def openai_completions_proxy(request: Request):
|
|||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||
|
||||
# 2. Endpoint logic
|
||||
endpoint = await choose_endpoint(model)
|
||||
# Normalize model name for tracking so it matches the PS table key
|
||||
tracking_model = get_tracking_model(endpoint, model)
|
||||
await increment_usage(endpoint, tracking_model)
|
||||
endpoint, tracking_model = await choose_endpoint(model)
|
||||
base_url = ep2base(endpoint)
|
||||
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||
|
||||
|
|
@ -3056,12 +3019,13 @@ async def rerank_proxy(request: Request):
|
|||
|
||||
# Determine which endpoint serves this model
|
||||
try:
|
||||
endpoint = await choose_endpoint(model)
|
||||
endpoint, tracking_model = await choose_endpoint(model)
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
# Ollama endpoints have no native rerank support
|
||||
if not is_openai_compatible(endpoint):
|
||||
await decrement_usage(endpoint, tracking_model)
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail=(
|
||||
|
|
@ -3071,13 +3035,9 @@ async def rerank_proxy(request: Request):
|
|||
),
|
||||
)
|
||||
|
||||
# Normalize model name for tracking
|
||||
tracking_model = get_tracking_model(endpoint, model)
|
||||
if ":latest" in model:
|
||||
model = model.split(":latest")[0]
|
||||
|
||||
await increment_usage(endpoint, tracking_model)
|
||||
|
||||
# Build upstream rerank request body – forward only recognised fields
|
||||
upstream_payload: dict = {"model": model, "query": query, "documents": documents}
|
||||
for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue