feat(router): normalize model names for usage tracking across endpoints (continued)

Introduce `get_tracking_model()` to standardize model names for consistent usage tracking in Prometheus metrics. Model names on llama-server endpoints are stripped of the HF prefix and quantization suffix, Ollama model names get `:latest` appended when no version tag is present, and external OpenAI model names are passed through unchanged, so all tracking keys line up with the PS table.
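For illustration, the intended mapping can be sketched as below. This is a minimal sketch under stated assumptions, not the repository's code: the endpoint URLs and model names are made up, and the suffix handling in `normalize_llama_model_name` merely stands in for the real `_normalize_llama_model_name` and `config` lookups, whose implementations are not shown in this diff.

```python
# Illustrative sketch only: endpoints, model names, and the quantization-suffix
# list are assumptions, not the router's actual configuration.
def normalize_llama_model_name(model: str) -> str:
    # Drop a Hugging Face "org/" prefix and a trailing quant suffix (assumed convention).
    name = model.split("/")[-1]
    for quant in ("-Q4_K_M", "-Q5_K_S", "-Q8_0"):
        if name.endswith(quant):
            name = name[: -len(quant)]
    return name

def tracking_model_for(endpoint: str, model: str,
                       llama_endpoints: set, external_endpoints: set) -> str:
    if endpoint in external_endpoints:   # external OpenAI: not shown in PS, keep as-is
        return model
    if endpoint in llama_endpoints:      # llama-server: normalized name
        return normalize_llama_model_name(model)
    if ":" not in model:                 # Ollama: append ":latest" if versionless
        return model + ":latest"
    return model

# Expected behavior with hypothetical names:
assert tracking_model_for("http://llama:8080", "bartowski/Qwen2.5-7B-Instruct-Q4_K_M",
                          {"http://llama:8080"}, set()) == "Qwen2.5-7B-Instruct"
assert tracking_model_for("http://ollama:11434", "qwen2.5", set(), set()) == "qwen2.5:latest"
assert tracking_model_for("https://api.openai.com/v1", "gpt-4o-mini",
                          set(), {"https://api.openai.com/v1"}) == "gpt-4o-mini"
```

These tracking keys then feed `increment_usage()` and the token queue, so the Prometheus counters aggregate under the same name the PS table reports.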
Alpha Nerd 2026-02-18 11:45:37 +01:00
parent b2980a7d24
commit 7cba67cce0

router.py (100 lines changed)

@@ -417,6 +417,30 @@ def is_openai_compatible(endpoint: str) -> bool:
    """
    return "/v1" in endpoint or endpoint in config.llama_server_endpoints

def get_tracking_model(endpoint: str, model: str) -> str:
    """
    Normalize model name for tracking purposes so it matches the PS table key.
    - For llama-server endpoints: strips HF prefix and quantization suffix
    - For Ollama endpoints: appends ":latest" if no version suffix is present
    - For external OpenAI endpoints: returns as-is (not shown in PS)
    This ensures consistent model naming across all routes for usage tracking.
    """
    # External OpenAI endpoints are not shown in PS, keep as-is
    if is_ext_openai_endpoint(endpoint):
        return model
    # llama-server endpoints use normalized names in PS
    if endpoint in config.llama_server_endpoints:
        return _normalize_llama_model_name(model)
    # Ollama endpoints: append ":latest" if no version suffix
    if ":" not in model:
        return model + ":latest"
    return model

async def token_worker() -> None:
    try:
        while True:
@@ -935,7 +959,9 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, model)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
try:
if use_openai:
@@ -958,7 +984,7 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
# Convert to Ollama format
if chunks:
response = rechunk.openai_chat_completion2ollama(chunks[-1], stream, start_ts)
@@ -976,7 +1002,7 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
response = rechunk.openai_chat_completion2ollama(response, stream, start_ts)
else:
response = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
@@ -988,18 +1014,18 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
prompt_tok = chunk.prompt_eval_count or 0
comp_tok = chunk.eval_count or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
if chunks:
response = chunks[-1]
else:
prompt_tok = response.prompt_eval_count or 0
comp_tok = response.eval_count or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
return response
finally:
await decrement_usage(endpoint, model)
await decrement_usage(endpoint, tracking_model)
def get_last_user_content(messages):
"""
@@ -1594,6 +1620,8 @@ async def proxy(request: Request):
endpoint = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")[0]
@@ -1618,7 +1646,7 @@ async def proxy(request: Request):
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, model)
await increment_usage(endpoint, tracking_model)
# 4. Async generator that streams data and decrements the counter
async def stream_generate_response():
@@ -1635,7 +1663,7 @@ async def proxy(request: Request):
prompt_tok = chunk.prompt_eval_count or 0
comp_tok = chunk.eval_count or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
if hasattr(chunk, "model_dump_json"):
json_line = chunk.model_dump_json()
else:
@@ -1650,7 +1678,7 @@ async def proxy(request: Request):
prompt_tok = async_gen.prompt_eval_count or 0
comp_tok = async_gen.eval_count or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
json_line = (
response
if hasattr(async_gen, "model_dump_json")
@@ -1660,7 +1688,7 @@ async def proxy(request: Request):
finally:
# Ensure counter is decremented even if an exception occurs
await decrement_usage(endpoint, model)
await decrement_usage(endpoint, tracking_model)
# 5. Return a StreamingResponse backed by the generator
return StreamingResponse(
@@ -1715,6 +1743,8 @@ async def chat_proxy(request: Request):
opt = False
endpoint = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")[0]
@@ -1745,7 +1775,7 @@ async def chat_proxy(request: Request):
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, model)
await increment_usage(endpoint, tracking_model)
# 3. Async generator that streams chat data and decrements the counter
async def stream_chat_response():
try:
@@ -1772,7 +1802,7 @@ async def chat_proxy(request: Request):
prompt_tok = chunk.prompt_eval_count or 0
comp_tok = chunk.eval_count or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
if hasattr(chunk, "model_dump_json"):
json_line = chunk.model_dump_json()
else:
@@ -1787,7 +1817,7 @@ async def chat_proxy(request: Request):
prompt_tok = async_gen.prompt_eval_count or 0
comp_tok = async_gen.eval_count or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
json_line = (
response
if hasattr(async_gen, "model_dump_json")
@@ -1797,7 +1827,7 @@ async def chat_proxy(request: Request):
finally:
# Ensure counter is decremented even if an exception occurs
await decrement_usage(endpoint, model)
await decrement_usage(endpoint, tracking_model)
# 4. Return a StreamingResponse backed by the generator
media_type = "application/x-ndjson" if stream else "application/json"
@@ -1839,6 +1869,8 @@ async def embedding_proxy(request: Request):
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")[0]
@@ -1846,7 +1878,7 @@ async def embedding_proxy(request: Request):
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, model)
await increment_usage(endpoint, tracking_model)
# 3. Async generator that streams embedding data and decrements the counter
async def stream_embedding_response():
try:
@@ -1863,7 +1895,7 @@ async def embedding_proxy(request: Request):
yield json_line.encode("utf-8") + b"\n"
finally:
# Ensure counter is decremented even if an exception occurs
await decrement_usage(endpoint, model)
await decrement_usage(endpoint, tracking_model)
# 5. Return a StreamingResponse backed by the generator
return StreamingResponse(
@@ -1905,6 +1937,8 @@ async def embed_proxy(request: Request):
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")[0]
@@ -1912,7 +1946,7 @@ async def embed_proxy(request: Request):
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, model)
await increment_usage(endpoint, tracking_model)
# 3. Async generator that streams embed data and decrements the counter
async def stream_embedding_response():
try:
@@ -1929,7 +1963,7 @@ async def embed_proxy(request: Request):
yield json_line.encode("utf-8") + b"\n"
finally:
# Ensure counter is decremented even if an exception occurs
await decrement_usage(endpoint, model)
await decrement_usage(endpoint, tracking_model)
# 4. Return a StreamingResponse backed by the generator
return StreamingResponse(
@@ -2601,7 +2635,9 @@ async def openai_embedding_proxy(request: Request):
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
await increment_usage(endpoint, model)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
if is_openai_compatible(endpoint):
api_key = config.api_keys.get(endpoint, "no-key")
else:
@@ -2612,8 +2648,8 @@ async def openai_embedding_proxy(request: Request):
# 3. Async generator that streams embedding data and decrements the counter
async_gen = await oclient.embeddings.create(input=doc, model=model)
await decrement_usage(endpoint, model)
await decrement_usage(endpoint, tracking_model)
# 5. Return a StreamingResponse backed by the generator
return async_gen
@@ -2690,15 +2726,8 @@ async def openai_chat_completions_proxy(request: Request):
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
# Normalize model name for tracking so it matches the PS table key:
# - Ollama: PS reports "model:latest" → append ":latest" when missing
# - llama-server: PS reports _normalize_llama_model_name(id) → strip HF prefix/quant
# - External OpenAI: not shown in PS, keep as-is
tracking_model = model
if endpoint in config.llama_server_endpoints:
tracking_model = _normalize_llama_model_name(model)
elif not is_ext_openai_endpoint(endpoint) and ":" not in model:
tracking_model = model + ":latest"
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
@@ -2843,15 +2872,8 @@ async def openai_completions_proxy(request: Request):
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
# Normalize model name for tracking so it matches the PS table key:
# - Ollama: PS reports "model:latest" → append ":latest" when missing
# - llama-server: PS reports _normalize_llama_model_name(id) → strip HF prefix/quant
# - External OpenAI: not shown in PS, keep as-is
tracking_model = model
if endpoint in config.llama_server_endpoints:
tracking_model = _normalize_llama_model_name(model)
elif not is_ext_openai_endpoint(endpoint) and ":" not in model:
tracking_model = model + ":latest"
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))