diff --git a/router.py b/router.py
index d1f5797..e1d42bd 100644
--- a/router.py
+++ b/router.py
@@ -417,6 +417,30 @@ def is_openai_compatible(endpoint: str) -> bool:
     """
     return "/v1" in endpoint or endpoint in config.llama_server_endpoints

+def get_tracking_model(endpoint: str, model: str) -> str:
+    """
+    Normalize model name for tracking purposes so it matches the PS table key.
+
+    - For llama-server endpoints: strips HF prefix and quantization suffix
+    - For Ollama endpoints: appends ":latest" if no version suffix is present
+    - For external OpenAI endpoints: returns as-is (not shown in PS)
+
+    This ensures consistent model naming across all routes for usage tracking.
+    """
+    # External OpenAI endpoints are not shown in PS, keep as-is
+    if is_ext_openai_endpoint(endpoint):
+        return model
+
+    # llama-server endpoints use normalized names in PS
+    if endpoint in config.llama_server_endpoints:
+        return _normalize_llama_model_name(model)
+
+    # Ollama endpoints: append ":latest" if no version suffix
+    if ":" not in model:
+        return model + ":latest"
+
+    return model
+
 async def token_worker() -> None:
     try:
         while True:
@@ -935,7 +959,9 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
     else:
         client = ollama.AsyncClient(host=endpoint)

-    await increment_usage(endpoint, model)
+    # Normalize model name for tracking so it matches the PS table key
+    tracking_model = get_tracking_model(endpoint, model)
+    await increment_usage(endpoint, tracking_model)

     try:
         if use_openai:
@@ -958,7 +984,7 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
                     if llama_usage:
                         prompt_tok, comp_tok = llama_usage
                         if prompt_tok != 0 or comp_tok != 0:
-                            await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                            await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                 # Convert to Ollama format
                 if chunks:
                     response = rechunk.openai_chat_completion2ollama(chunks[-1], stream, start_ts)
@@ -976,7 +1002,7 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
                 if llama_usage:
                     prompt_tok, comp_tok = llama_usage
                     if prompt_tok != 0 or comp_tok != 0:
-                        await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                 response = rechunk.openai_chat_completion2ollama(response, stream, start_ts)
         else:
             response = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
@@ -988,18 +1014,18 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
                     prompt_tok = chunk.prompt_eval_count or 0
                     comp_tok = chunk.eval_count or 0
                     if prompt_tok != 0 or comp_tok != 0:
-                        await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                 if chunks:
                     response = chunks[-1]
             else:
                 prompt_tok = response.prompt_eval_count or 0
                 comp_tok = response.eval_count or 0
                 if prompt_tok != 0 or comp_tok != 0:
-                    await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

         return response
     finally:
-        await decrement_usage(endpoint, model)
+        await decrement_usage(endpoint, tracking_model)

 def get_last_user_content(messages):
     """
@@ -1594,6 +1620,8 @@ async def proxy(request: Request):

     endpoint = await choose_endpoint(model)
     use_openai = is_openai_compatible(endpoint)
+    # Normalize model name for tracking so it matches the PS table key
+    tracking_model = get_tracking_model(endpoint, model)
     if use_openai:
         if ":latest" in model:
             model = model.split(":latest")
@@ -1618,7 +1646,7 @@ async def proxy(request: Request):
         oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
     else:
         client = ollama.AsyncClient(host=endpoint)
-    await increment_usage(endpoint, model)
+    await increment_usage(endpoint, tracking_model)

     # 4. Async generator that streams data and decrements the counter
     async def stream_generate_response():
@@ -1635,7 +1663,7 @@ async def proxy(request: Request):
                     prompt_tok = chunk.prompt_eval_count or 0
                     comp_tok = chunk.eval_count or 0
                     if prompt_tok != 0 or comp_tok != 0:
-                        await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                     if hasattr(chunk, "model_dump_json"):
                         json_line = chunk.model_dump_json()
                     else:
@@ -1650,7 +1678,7 @@ async def proxy(request: Request):
                 prompt_tok = async_gen.prompt_eval_count or 0
                 comp_tok = async_gen.eval_count or 0
                 if prompt_tok != 0 or comp_tok != 0:
-                    await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                 json_line = (
                     response
                     if hasattr(async_gen, "model_dump_json")
@@ -1660,7 +1688,7 @@ async def proxy(request: Request):

         finally:
             # Ensure counter is decremented even if an exception occurs
-            await decrement_usage(endpoint, model)
+            await decrement_usage(endpoint, tracking_model)

     # 5. Return a StreamingResponse backed by the generator
     return StreamingResponse(
@@ -1715,6 +1743,8 @@ async def chat_proxy(request: Request):
         opt = False
     endpoint = await choose_endpoint(model)
     use_openai = is_openai_compatible(endpoint)
+    # Normalize model name for tracking so it matches the PS table key
+    tracking_model = get_tracking_model(endpoint, model)
     if use_openai:
         if ":latest" in model:
             model = model.split(":latest")
@@ -1745,7 +1775,7 @@ async def chat_proxy(request: Request):
         oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
     else:
         client = ollama.AsyncClient(host=endpoint)
-    await increment_usage(endpoint, model)
+    await increment_usage(endpoint, tracking_model)
     # 3. Async generator that streams chat data and decrements the counter
     async def stream_chat_response():
         try:
@@ -1772,7 +1802,7 @@ async def chat_proxy(request: Request):
                     prompt_tok = chunk.prompt_eval_count or 0
                     comp_tok = chunk.eval_count or 0
                     if prompt_tok != 0 or comp_tok != 0:
-                        await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                     if hasattr(chunk, "model_dump_json"):
                         json_line = chunk.model_dump_json()
                     else:
@@ -1787,7 +1817,7 @@ async def chat_proxy(request: Request):
                 prompt_tok = async_gen.prompt_eval_count or 0
                 comp_tok = async_gen.eval_count or 0
                 if prompt_tok != 0 or comp_tok != 0:
-                    await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                 json_line = (
                     response
                     if hasattr(async_gen, "model_dump_json")
@@ -1797,7 +1827,7 @@ async def chat_proxy(request: Request):

         finally:
             # Ensure counter is decremented even if an exception occurs
-            await decrement_usage(endpoint, model)
+            await decrement_usage(endpoint, tracking_model)

     # 4. Return a StreamingResponse backed by the generator
     media_type = "application/x-ndjson" if stream else "application/json"
@@ -1839,6 +1869,8 @@ async def embedding_proxy(request: Request):
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     use_openai = is_openai_compatible(endpoint)
+    # Normalize model name for tracking so it matches the PS table key
+    tracking_model = get_tracking_model(endpoint, model)
     if use_openai:
         if ":latest" in model:
             model = model.split(":latest")
@@ -1846,7 +1878,7 @@ async def embedding_proxy(request: Request):
         client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
     else:
         client = ollama.AsyncClient(host=endpoint)
-    await increment_usage(endpoint, model)
+    await increment_usage(endpoint, tracking_model)
     # 3. Async generator that streams embedding data and decrements the counter
     async def stream_embedding_response():
         try:
@@ -1863,7 +1895,7 @@ async def embedding_proxy(request: Request):
             yield json_line.encode("utf-8") + b"\n"
         finally:
             # Ensure counter is decremented even if an exception occurs
-            await decrement_usage(endpoint, model)
+            await decrement_usage(endpoint, tracking_model)

     # 5. Return a StreamingResponse backed by the generator
     return StreamingResponse(
@@ -1905,6 +1937,8 @@ async def embed_proxy(request: Request):
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     use_openai = is_openai_compatible(endpoint)
+    # Normalize model name for tracking so it matches the PS table key
+    tracking_model = get_tracking_model(endpoint, model)
     if use_openai:
         if ":latest" in model:
             model = model.split(":latest")
@@ -1912,7 +1946,7 @@ async def embed_proxy(request: Request):
         client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
     else:
         client = ollama.AsyncClient(host=endpoint)
-    await increment_usage(endpoint, model)
+    await increment_usage(endpoint, tracking_model)
     # 3. Async generator that streams embed data and decrements the counter
     async def stream_embedding_response():
         try:
@@ -1929,7 +1963,7 @@ async def embed_proxy(request: Request):
             yield json_line.encode("utf-8") + b"\n"
         finally:
             # Ensure counter is decremented even if an exception occurs
-            await decrement_usage(endpoint, model)
+            await decrement_usage(endpoint, tracking_model)

     # 4. Return a StreamingResponse backed by the generator
     return StreamingResponse(
@@ -2601,7 +2635,9 @@ async def openai_embedding_proxy(request: Request):

     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
-    await increment_usage(endpoint, model)
+    # Normalize model name for tracking so it matches the PS table key
+    tracking_model = get_tracking_model(endpoint, model)
+    await increment_usage(endpoint, tracking_model)
     if is_openai_compatible(endpoint):
         api_key = config.api_keys.get(endpoint, "no-key")
     else:
@@ -2612,8 +2648,8 @@ async def openai_embedding_proxy(request: Request):

     # 3. Async generator that streams embedding data and decrements the counter
     async_gen = await oclient.embeddings.create(input=doc, model=model)
-
-    await decrement_usage(endpoint, model)
+
+    await decrement_usage(endpoint, tracking_model)

     # 5. Return a StreamingResponse backed by the generator
     return async_gen
@@ -2690,15 +2726,8 @@ async def openai_chat_completions_proxy(request: Request):

     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
-    # Normalize model name for tracking so it matches the PS table key:
-    # - Ollama: PS reports "model:latest" → append ":latest" when missing
-    # - llama-server: PS reports _normalize_llama_model_name(id) → strip HF prefix/quant
-    # - External OpenAI: not shown in PS, keep as-is
-    tracking_model = model
-    if endpoint in config.llama_server_endpoints:
-        tracking_model = _normalize_llama_model_name(model)
-    elif not is_ext_openai_endpoint(endpoint) and ":" not in model:
-        tracking_model = model + ":latest"
+    # Normalize model name for tracking so it matches the PS table key
+    tracking_model = get_tracking_model(endpoint, model)
     await increment_usage(endpoint, tracking_model)
     base_url = ep2base(endpoint)
     oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
@@ -2843,15 +2872,8 @@ async def openai_completions_proxy(request: Request):

     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
-    # Normalize model name for tracking so it matches the PS table key:
-    # - Ollama: PS reports "model:latest" → append ":latest" when missing
-    # - llama-server: PS reports _normalize_llama_model_name(id) → strip HF prefix/quant
-    # - External OpenAI: not shown in PS, keep as-is
-    tracking_model = model
-    if endpoint in config.llama_server_endpoints:
-        tracking_model = _normalize_llama_model_name(model)
-    elif not is_ext_openai_endpoint(endpoint) and ":" not in model:
-        tracking_model = model + ":latest"
+    # Normalize model name for tracking so it matches the PS table key
+    tracking_model = get_tracking_model(endpoint, model)
     await increment_usage(endpoint, tracking_model)
     base_url = ep2base(endpoint)
     oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
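
A rough, self-contained sketch of how the new helper is expected to behave for each route type. The endpoint URLs, model names, and the stand-in sets and normalize_llama_model_name stub below are hypothetical placeholders for config.llama_server_endpoints, is_ext_openai_endpoint, and _normalize_llama_model_name in router.py; the real normalization rules live there.

# Illustrative sketch of get_tracking_model's dispatch, using hypothetical
# stand-ins for the router's config and helpers; values are placeholders.

LLAMA_SERVER_ENDPOINTS = {"http://localhost:8080/v1"}        # stands in for config.llama_server_endpoints
EXTERNAL_OPENAI_ENDPOINTS = {"https://api.example.com/v1"}   # stands in for is_ext_openai_endpoint()


def normalize_llama_model_name(model: str) -> str:
    # Placeholder for _normalize_llama_model_name: the real helper strips the
    # HF prefix and quantization suffix; only the rough shape is shown here.
    return model.split("/")[-1].split(":")[0]


def get_tracking_model(endpoint: str, model: str) -> str:
    # Same dispatch order as the patch: external OpenAI first, then
    # llama-server, then the Ollama ":latest" default.
    if endpoint in EXTERNAL_OPENAI_ENDPOINTS:     # is_ext_openai_endpoint(endpoint)
        return model
    if endpoint in LLAMA_SERVER_ENDPOINTS:        # config.llama_server_endpoints
        return normalize_llama_model_name(model)
    if ":" not in model:
        return model + ":latest"
    return model


# Ollama PS reports bare models as "name:latest", so bare names gain the suffix.
assert get_tracking_model("http://localhost:11434", "qwen3") == "qwen3:latest"
# Names that already carry a tag pass through unchanged.
assert get_tracking_model("http://localhost:11434", "qwen3:8b") == "qwen3:8b"
# External OpenAI endpoints are not shown in PS, so the name is kept as-is.
assert get_tracking_model("https://api.example.com/v1", "gpt-4o") == "gpt-4o"

Checking the external OpenAI case first mirrors the patch: those models never appear in the PS table, so they skip normalization entirely and are tracked under their original names.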