113 lines
3.4 KiB
Python
113 lines
3.4 KiB
Python
"""Endpoint URL, model-name, and endpoint-classification helpers.
|
|
|
|
The endpoint classifiers read live config via ``get_config()`` so that the
|
|
startup-time rebind of ``config`` in router.py is picked up at call time.
|
|
"""
|
|
from config import get_config
|
|
|
|
|
|
def _normalize_llama_model_name(name: str) -> str:
|
|
"""Extract the model name from a huggingface-style identifier.
|
|
e.g. 'unsloth/gpt-oss-20b-GGUF:F16' -> 'gpt-oss-20b-GGUF'
|
|
"""
|
|
if "/" in name:
|
|
name = name.rsplit("/", 1)[1]
|
|
if ":" in name:
|
|
name = name.split(":")[0]
|
|
return name
|
|
|
|
|
|
def _extract_llama_quant(name: str) -> str:
|
|
"""Extract the quantization level from a huggingface-style identifier.
|
|
e.g. 'unsloth/gpt-oss-20b-GGUF:Q8_0' -> 'Q8_0'
|
|
Returns empty string if no quant suffix is present.
|
|
"""
|
|
if ":" in name:
|
|
return name.rsplit(":", 1)[1]
|
|
return ""
|
|
|
|
|
|
def ep2base(ep: str) -> str:
    """Return the OpenAI-style base URL (ending in ``/v1``) for an endpoint.

    An endpoint that already contains ``/v1`` is returned unchanged.
    Otherwise ``/v1`` is appended, after stripping any trailing slashes so
    that ``http://host:11434/`` becomes ``http://host:11434/v1`` instead of
    ``http://host:11434//v1``.
    """
    if "/v1" in ep:
        return ep
    # Strip trailing "/" first to avoid producing a doubled slash in the path.
    return ep.rstrip("/") + "/v1"
|
|
|
|
|
|
def dedupe_on_keys(dicts, key_fields):
    """
    Drop dicts whose values for ``key_fields`` repeat those of an earlier dict.

    The first dict seen for each combination of values is kept and the
    input order is preserved. A field missing from a dict counts as ``None``.
    """
    first_by_signature = {}
    for entry in dicts:
        # The tuple of selected values identifies the entry.
        signature = tuple(entry.get(field) for field in key_fields)
        # setdefault stores only the first entry per signature; the dict's
        # insertion order then gives first-seen ordering for free.
        first_by_signature.setdefault(signature, entry)
    return list(first_by_signature.values())
|
|
|
|
|
|
def is_ext_openai_endpoint(endpoint: str) -> bool:
    """
    Tell whether an endpoint is an external OpenAI-compatible service.

    Returns False when any of the following holds:
    - the endpoint is listed in the configured ``llama_server_endpoints``
    - the endpoint does not contain "/v1"
    - the endpoint with "/v1" removed is listed in the configured ``endpoints``
      (treated as Ollama's /v1 interface)
    - the endpoint contains ":11434" (the default Ollama port)

    Returns True otherwise (e.g. OpenAI.com, Groq, etc.).
    """
    cfg = get_config()

    # llama-server endpoints and endpoints without /v1 are never external.
    is_llama_server = endpoint in cfg.llama_server_endpoints
    if is_llama_server or "/v1" not in endpoint:
        return False

    # A configured Ollama endpoint reached via its /v1 path, or anything on
    # the default Ollama port, is still Ollama.
    ollama_base = endpoint.replace('/v1', '')
    looks_like_ollama = ollama_base in cfg.endpoints or ':11434' in endpoint

    return not looks_like_ollama
|
|
|
|
|
|
def is_openai_compatible(endpoint: str) -> bool:
    """
    Tell whether the endpoint speaks the OpenAI API rather than native Ollama.

    True for any endpoint containing "/v1" and for any endpoint listed in the
    configured ``llama_server_endpoints``.
    """
    # A /v1 path settles the question without consulting the config.
    if "/v1" in endpoint:
        return True
    return endpoint in get_config().llama_server_endpoints
|
|
|
|
|
|
def get_tracking_model(endpoint: str, model: str) -> str:
    """
    Map a model name to the form used as the PS table key for usage tracking.

    - External OpenAI endpoints: the name is returned unchanged (not shown in PS)
    - llama-server endpoints: HF prefix and quantization suffix are stripped
    - Ollama endpoints: ":latest" is appended when the name has no ":" suffix

    Using this everywhere keeps model naming consistent across routes.
    """
    # External services are checked first; their names pass through untouched.
    if is_ext_openai_endpoint(endpoint):
        return model

    # llama-server models are tracked under their normalized name.
    llama_servers = get_config().llama_server_endpoints
    if endpoint in llama_servers:
        return _normalize_llama_model_name(model)

    # Anything else is Ollama: add ":latest" when no tag is given.
    return model if ":" in model else model + ":latest"
|