refactor: make choose_endpoint use cache incrementer for atomic updates
This commit is contained in:
parent
e7196146ad
commit
e96e890511
1 changed file with 50 additions and 90 deletions
118
router.py
118
router.py
|
|
@ -928,11 +928,12 @@ async def decrement_usage(endpoint: str, model: str) -> None:
|
||||||
# usage_counts.pop(endpoint, None)
|
# usage_counts.pop(endpoint, None)
|
||||||
await publish_snapshot()
|
await publish_snapshot()
|
||||||
|
|
||||||
async def _make_chat_request(endpoint: str, model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
|
async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
|
||||||
"""
|
"""
|
||||||
Helper function to make a chat request to a specific endpoint.
|
Helper function to make a chat request to an automatically selected endpoint.
|
||||||
Handles endpoint selection, client creation, usage tracking, and request execution.
|
Handles endpoint selection, client creation, usage tracking, and request execution.
|
||||||
"""
|
"""
|
||||||
|
endpoint, tracking_model = await choose_endpoint(model) # selects and atomically reserves
|
||||||
use_openai = is_openai_compatible(endpoint)
|
use_openai = is_openai_compatible(endpoint)
|
||||||
if use_openai:
|
if use_openai:
|
||||||
if ":latest" in model:
|
if ":latest" in model:
|
||||||
|
|
@ -962,10 +963,6 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
|
||||||
else:
|
else:
|
||||||
client = ollama.AsyncClient(host=endpoint)
|
client = ollama.AsyncClient(host=endpoint)
|
||||||
|
|
||||||
# Normalize model name for tracking so it matches the PS table key
|
|
||||||
tracking_model = get_tracking_model(endpoint, model)
|
|
||||||
await increment_usage(endpoint, tracking_model)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if use_openai:
|
if use_openai:
|
||||||
start_ts = time.perf_counter()
|
start_ts = time.perf_counter()
|
||||||
|
|
@ -1057,18 +1054,11 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
|
||||||
|
|
||||||
moe_reqs = []
|
moe_reqs = []
|
||||||
|
|
||||||
# Generate 3 responses
|
# Generate 3 responses — choose_endpoint is called inside _make_chat_request and
|
||||||
response1_endpoint = await choose_endpoint(model)
|
# atomically reserves a slot, so all 3 tasks see each other's load immediately.
|
||||||
response1_task = asyncio.create_task(_make_chat_request(response1_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
response1_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||||
await asyncio.sleep(0.01) # Small delay to allow usage count to update
|
response2_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||||
|
response3_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||||
response2_endpoint = await choose_endpoint(model)
|
|
||||||
response2_task = asyncio.create_task(_make_chat_request(response2_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
|
||||||
await asyncio.sleep(0.01) # Small delay to allow usage count to update
|
|
||||||
|
|
||||||
response3_endpoint = await choose_endpoint(model)
|
|
||||||
response3_task = asyncio.create_task(_make_chat_request(response3_endpoint, model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
|
||||||
await asyncio.sleep(0.01) # Small delay to allow usage count to update
|
|
||||||
|
|
||||||
responses = await asyncio.gather(response1_task, response2_task, response3_task)
|
responses = await asyncio.gather(response1_task, response2_task, response3_task)
|
||||||
|
|
||||||
|
|
@ -1077,17 +1067,9 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
|
||||||
moe_reqs.append(moe_req)
|
moe_reqs.append(moe_req)
|
||||||
|
|
||||||
# Generate 3 critiques
|
# Generate 3 critiques
|
||||||
critique1_endpoint = await choose_endpoint(model)
|
critique1_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||||
critique1_task = asyncio.create_task(_make_chat_request(critique1_endpoint, model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
critique2_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||||
await asyncio.sleep(0.01) # Small delay to allow usage count to update
|
critique3_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
||||||
|
|
||||||
critique2_endpoint = await choose_endpoint(model)
|
|
||||||
critique2_task = asyncio.create_task(_make_chat_request(critique2_endpoint, model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
|
||||||
await asyncio.sleep(0.01) # Small delay to allow usage count to update
|
|
||||||
|
|
||||||
critique3_endpoint = await choose_endpoint(model)
|
|
||||||
critique3_task = asyncio.create_task(_make_chat_request(critique3_endpoint, model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
|
|
||||||
await asyncio.sleep(0.01) # Small delay to allow usage count to update
|
|
||||||
|
|
||||||
critiques = await asyncio.gather(critique1_task, critique2_task, critique3_task)
|
critiques = await asyncio.gather(critique1_task, critique2_task, critique3_task)
|
||||||
|
|
||||||
|
|
@ -1095,8 +1077,7 @@ async def _make_moe_requests(model: str, messages: list, tools=None, think: bool
|
||||||
m = enhance.moe_select_candidate(query, critiques)
|
m = enhance.moe_select_candidate(query, critiques)
|
||||||
|
|
||||||
# Generate final response
|
# Generate final response
|
||||||
final_endpoint = await choose_endpoint(model)
|
return await _make_chat_request(model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
|
||||||
return await _make_chat_request(final_endpoint, model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
|
|
||||||
|
|
||||||
def iso8601_ns():
|
def iso8601_ns():
|
||||||
ns = time.time_ns()
|
ns = time.time_ns()
|
||||||
|
|
@ -1465,7 +1446,7 @@ async def get_usage_counts() -> Dict:
|
||||||
# -------------------------------------------------------------
|
# -------------------------------------------------------------
|
||||||
# 5. Endpoint selection logic (respecting the configurable limit)
|
# 5. Endpoint selection logic (respecting the configurable limit)
|
||||||
# -------------------------------------------------------------
|
# -------------------------------------------------------------
|
||||||
async def choose_endpoint(model: str) -> str:
|
async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Determine which endpoint to use for the given model while respecting
|
Determine which endpoint to use for the given model while respecting
|
||||||
the `max_concurrent_connections` per endpoint‑model pair **and**
|
the `max_concurrent_connections` per endpoint‑model pair **and**
|
||||||
|
|
@ -1526,7 +1507,8 @@ async def choose_endpoint(model: str) -> str:
|
||||||
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
|
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
|
||||||
loaded_sets = await asyncio.gather(*load_tasks)
|
loaded_sets = await asyncio.gather(*load_tasks)
|
||||||
|
|
||||||
# Protect all reads of usage_counts with the lock
|
# Protect all reads/writes of usage_counts with the lock so that selection
|
||||||
|
# and reservation are atomic — concurrent callers see each other's pending load.
|
||||||
async with usage_lock:
|
async with usage_lock:
|
||||||
# Helper: current usage for (endpoint, model) using the same normalized key
|
# Helper: current usage for (endpoint, model) using the same normalized key
|
||||||
# that increment_usage/decrement_usage store — raw model names differ from
|
# that increment_usage/decrement_usage store — raw model names differ from
|
||||||
|
|
@ -1544,14 +1526,13 @@ async def choose_endpoint(model: str) -> str:
|
||||||
# Sort ascending for load balancing — all endpoints here already have the
|
# Sort ascending for load balancing — all endpoints here already have the
|
||||||
# model loaded, so there is no model-switching cost to optimise for.
|
# model loaded, so there is no model-switching cost to optimise for.
|
||||||
loaded_and_free.sort(key=tracking_usage)
|
loaded_and_free.sort(key=tracking_usage)
|
||||||
|
|
||||||
# When all candidates are equally idle, randomise to avoid always picking
|
# When all candidates are equally idle, randomise to avoid always picking
|
||||||
# the first entry in a stable sort.
|
# the first entry in a stable sort.
|
||||||
if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
|
if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
|
||||||
return random.choice(loaded_and_free)
|
selected = random.choice(loaded_and_free)
|
||||||
|
else:
|
||||||
return loaded_and_free[0]
|
selected = loaded_and_free[0]
|
||||||
|
else:
|
||||||
# 4️⃣ Endpoints among the candidates that simply have a free slot
|
# 4️⃣ Endpoints among the candidates that simply have a free slot
|
||||||
endpoints_with_free_slot = [
|
endpoints_with_free_slot = [
|
||||||
ep for ep in candidate_endpoints
|
ep for ep in candidate_endpoints
|
||||||
|
|
@ -1563,15 +1544,19 @@ async def choose_endpoint(model: str) -> str:
|
||||||
endpoints_with_free_slot.sort(
|
endpoints_with_free_slot.sort(
|
||||||
key=lambda ep: sum(usage_counts.get(ep, {}).values())
|
key=lambda ep: sum(usage_counts.get(ep, {}).values())
|
||||||
)
|
)
|
||||||
|
|
||||||
if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
|
if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
|
||||||
return random.choice(endpoints_with_free_slot)
|
selected = random.choice(endpoints_with_free_slot)
|
||||||
|
else:
|
||||||
return endpoints_with_free_slot[0]
|
selected = endpoints_with_free_slot[0]
|
||||||
|
else:
|
||||||
# 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue)
|
# 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue)
|
||||||
ep = min(candidate_endpoints, key=tracking_usage)
|
selected = min(candidate_endpoints, key=tracking_usage)
|
||||||
return ep
|
|
||||||
|
tracking_model = get_tracking_model(selected, model)
|
||||||
|
if reserve:
|
||||||
|
usage_counts[selected][tracking_model] += 1
|
||||||
|
await publish_snapshot()
|
||||||
|
return selected, tracking_model
|
||||||
|
|
||||||
# -------------------------------------------------------------
|
# -------------------------------------------------------------
|
||||||
# 6. API route – Generate
|
# 6. API route – Generate
|
||||||
|
|
@ -1612,10 +1597,8 @@ async def proxy(request: Request):
|
||||||
raise HTTPException(status_code=400, detail=error_msg) from e
|
raise HTTPException(status_code=400, detail=error_msg) from e
|
||||||
|
|
||||||
|
|
||||||
endpoint = await choose_endpoint(model)
|
endpoint, tracking_model = await choose_endpoint(model)
|
||||||
use_openai = is_openai_compatible(endpoint)
|
use_openai = is_openai_compatible(endpoint)
|
||||||
# Normalize model name for tracking so it matches the PS table key
|
|
||||||
tracking_model = get_tracking_model(endpoint, model)
|
|
||||||
if use_openai:
|
if use_openai:
|
||||||
if ":latest" in model:
|
if ":latest" in model:
|
||||||
model = model.split(":latest")
|
model = model.split(":latest")
|
||||||
|
|
@ -1640,7 +1623,6 @@ async def proxy(request: Request):
|
||||||
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||||
else:
|
else:
|
||||||
client = ollama.AsyncClient(host=endpoint)
|
client = ollama.AsyncClient(host=endpoint)
|
||||||
await increment_usage(endpoint, tracking_model)
|
|
||||||
|
|
||||||
# 4. Async generator that streams data and decrements the counter
|
# 4. Async generator that streams data and decrements the counter
|
||||||
async def stream_generate_response():
|
async def stream_generate_response():
|
||||||
|
|
@ -1735,10 +1717,8 @@ async def chat_proxy(request: Request):
|
||||||
opt = True
|
opt = True
|
||||||
else:
|
else:
|
||||||
opt = False
|
opt = False
|
||||||
endpoint = await choose_endpoint(model)
|
endpoint, tracking_model = await choose_endpoint(model)
|
||||||
use_openai = is_openai_compatible(endpoint)
|
use_openai = is_openai_compatible(endpoint)
|
||||||
# Normalize model name for tracking so it matches the PS table key
|
|
||||||
tracking_model = get_tracking_model(endpoint, model)
|
|
||||||
if use_openai:
|
if use_openai:
|
||||||
if ":latest" in model:
|
if ":latest" in model:
|
||||||
model = model.split(":latest")
|
model = model.split(":latest")
|
||||||
|
|
@ -1769,7 +1749,6 @@ async def chat_proxy(request: Request):
|
||||||
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||||
else:
|
else:
|
||||||
client = ollama.AsyncClient(host=endpoint)
|
client = ollama.AsyncClient(host=endpoint)
|
||||||
await increment_usage(endpoint, tracking_model)
|
|
||||||
# 3. Async generator that streams chat data and decrements the counter
|
# 3. Async generator that streams chat data and decrements the counter
|
||||||
async def stream_chat_response():
|
async def stream_chat_response():
|
||||||
try:
|
try:
|
||||||
|
|
@ -1861,10 +1840,8 @@ async def embedding_proxy(request: Request):
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
# 2. Endpoint logic
|
# 2. Endpoint logic
|
||||||
endpoint = await choose_endpoint(model)
|
endpoint, tracking_model = await choose_endpoint(model)
|
||||||
use_openai = is_openai_compatible(endpoint)
|
use_openai = is_openai_compatible(endpoint)
|
||||||
# Normalize model name for tracking so it matches the PS table key
|
|
||||||
tracking_model = get_tracking_model(endpoint, model)
|
|
||||||
if use_openai:
|
if use_openai:
|
||||||
if ":latest" in model:
|
if ":latest" in model:
|
||||||
model = model.split(":latest")
|
model = model.split(":latest")
|
||||||
|
|
@ -1872,7 +1849,6 @@ async def embedding_proxy(request: Request):
|
||||||
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
|
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
|
||||||
else:
|
else:
|
||||||
client = ollama.AsyncClient(host=endpoint)
|
client = ollama.AsyncClient(host=endpoint)
|
||||||
await increment_usage(endpoint, tracking_model)
|
|
||||||
# 3. Async generator that streams embedding data and decrements the counter
|
# 3. Async generator that streams embedding data and decrements the counter
|
||||||
async def stream_embedding_response():
|
async def stream_embedding_response():
|
||||||
try:
|
try:
|
||||||
|
|
@ -1929,10 +1905,8 @@ async def embed_proxy(request: Request):
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
# 2. Endpoint logic
|
# 2. Endpoint logic
|
||||||
endpoint = await choose_endpoint(model)
|
endpoint, tracking_model = await choose_endpoint(model)
|
||||||
use_openai = is_openai_compatible(endpoint)
|
use_openai = is_openai_compatible(endpoint)
|
||||||
# Normalize model name for tracking so it matches the PS table key
|
|
||||||
tracking_model = get_tracking_model(endpoint, model)
|
|
||||||
if use_openai:
|
if use_openai:
|
||||||
if ":latest" in model:
|
if ":latest" in model:
|
||||||
model = model.split(":latest")
|
model = model.split(":latest")
|
||||||
|
|
@ -1940,7 +1914,6 @@ async def embed_proxy(request: Request):
|
||||||
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
|
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
|
||||||
else:
|
else:
|
||||||
client = ollama.AsyncClient(host=endpoint)
|
client = ollama.AsyncClient(host=endpoint)
|
||||||
await increment_usage(endpoint, tracking_model)
|
|
||||||
# 3. Async generator that streams embed data and decrements the counter
|
# 3. Async generator that streams embed data and decrements the counter
|
||||||
async def stream_embedding_response():
|
async def stream_embedding_response():
|
||||||
try:
|
try:
|
||||||
|
|
@ -2038,8 +2011,7 @@ async def show_proxy(request: Request, model: Optional[str] = None):
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
# 2. Endpoint logic
|
# 2. Endpoint logic
|
||||||
endpoint = await choose_endpoint(model)
|
endpoint, _ = await choose_endpoint(model, reserve=False)
|
||||||
#await increment_usage(endpoint, model)
|
|
||||||
|
|
||||||
client = ollama.AsyncClient(host=endpoint)
|
client = ollama.AsyncClient(host=endpoint)
|
||||||
|
|
||||||
|
|
@ -2628,10 +2600,7 @@ async def openai_embedding_proxy(request: Request):
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
# 2. Endpoint logic
|
# 2. Endpoint logic
|
||||||
endpoint = await choose_endpoint(model)
|
endpoint, tracking_model = await choose_endpoint(model)
|
||||||
# Normalize model name for tracking so it matches the PS table key
|
|
||||||
tracking_model = get_tracking_model(endpoint, model)
|
|
||||||
await increment_usage(endpoint, tracking_model)
|
|
||||||
if is_openai_compatible(endpoint):
|
if is_openai_compatible(endpoint):
|
||||||
api_key = config.api_keys.get(endpoint, "no-key")
|
api_key = config.api_keys.get(endpoint, "no-key")
|
||||||
else:
|
else:
|
||||||
|
|
@ -2722,10 +2691,7 @@ async def openai_chat_completions_proxy(request: Request):
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
# 2. Endpoint logic
|
# 2. Endpoint logic
|
||||||
endpoint = await choose_endpoint(model)
|
endpoint, tracking_model = await choose_endpoint(model)
|
||||||
# Normalize model name for tracking so it matches the PS table key
|
|
||||||
tracking_model = get_tracking_model(endpoint, model)
|
|
||||||
await increment_usage(endpoint, tracking_model)
|
|
||||||
base_url = ep2base(endpoint)
|
base_url = ep2base(endpoint)
|
||||||
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||||
# 3. Async generator that streams completions data and decrements the counter
|
# 3. Async generator that streams completions data and decrements the counter
|
||||||
|
|
@ -2868,10 +2834,7 @@ async def openai_completions_proxy(request: Request):
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
||||||
|
|
||||||
# 2. Endpoint logic
|
# 2. Endpoint logic
|
||||||
endpoint = await choose_endpoint(model)
|
endpoint, tracking_model = await choose_endpoint(model)
|
||||||
# Normalize model name for tracking so it matches the PS table key
|
|
||||||
tracking_model = get_tracking_model(endpoint, model)
|
|
||||||
await increment_usage(endpoint, tracking_model)
|
|
||||||
base_url = ep2base(endpoint)
|
base_url = ep2base(endpoint)
|
||||||
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
||||||
|
|
||||||
|
|
@ -3056,12 +3019,13 @@ async def rerank_proxy(request: Request):
|
||||||
|
|
||||||
# Determine which endpoint serves this model
|
# Determine which endpoint serves this model
|
||||||
try:
|
try:
|
||||||
endpoint = await choose_endpoint(model)
|
endpoint, tracking_model = await choose_endpoint(model)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|
||||||
# Ollama endpoints have no native rerank support
|
# Ollama endpoints have no native rerank support
|
||||||
if not is_openai_compatible(endpoint):
|
if not is_openai_compatible(endpoint):
|
||||||
|
await decrement_usage(endpoint, tracking_model)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=501,
|
status_code=501,
|
||||||
detail=(
|
detail=(
|
||||||
|
|
@ -3071,13 +3035,9 @@ async def rerank_proxy(request: Request):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Normalize model name for tracking
|
|
||||||
tracking_model = get_tracking_model(endpoint, model)
|
|
||||||
if ":latest" in model:
|
if ":latest" in model:
|
||||||
model = model.split(":latest")[0]
|
model = model.split(":latest")[0]
|
||||||
|
|
||||||
await increment_usage(endpoint, tracking_model)
|
|
||||||
|
|
||||||
# Build upstream rerank request body – forward only recognised fields
|
# Build upstream rerank request body – forward only recognised fields
|
||||||
upstream_payload: dict = {"model": model, "query": query, "documents": documents}
|
upstream_payload: dict = {"model": model, "query": query, "documents": documents}
|
||||||
for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
|
for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue