diff --git a/router.py b/router.py
index d8a318a..5780045 100644
--- a/router.py
+++ b/router.py
@@ -394,10 +394,19 @@ async def choose_endpoint(model: str) -> str:
     # 6️⃣
     if not candidate_endpoints:
-        raise RuntimeError(
-            f"None of the configured endpoints ({', '.join(config.endpoints)}) "
-            f"advertise the model '{model}'."
-        )
+        # Ollama appends ":latest" to untagged model names, a convention
+        # OpenAI-style endpoints don't use; retry with the bare name.
+        if ":latest" in model:
+            model = model.split(":", 1)[0]
+            candidate_endpoints = [
+                ep for ep, models in zip(config.endpoints, advertised_sets)
+                if model in models
+            ]
+        if not candidate_endpoints:
+            raise RuntimeError(
+                f"None of the configured endpoints ({', '.join(config.endpoints)}) "
+                f"advertise the model '{model}'."
+            )
 
     # 3️⃣ Among the candidates, find those that have the model *loaded*
     #     (concurrently, but only for the filtered list)
 
@@ -472,9 +481,11 @@ async def proxy(request: Request):
 
     endpoint = await choose_endpoint(model)
-    await increment_usage(endpoint, model)
     is_openai_endpoint = "/v1" in endpoint
     if is_openai_endpoint:
+        # strip Ollama's ":latest" tag before calling an OpenAI-style endpoint
+        if ":latest" in model:
+            model = model.split(":", 1)[0]
         params = {
             "prompt": prompt,
             "model": model,
@@ -488,6 +499,7 @@
         oclient = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
     else:
         client = ollama.AsyncClient(host=endpoint)
+    await increment_usage(endpoint, model)
 
     # 4. Async generator that streams data and decrements the counter
     async def stream_generate_response():
@@ -564,9 +576,11 @@ async def chat_proxy(request: Request):
 
     # 2. Endpoint logic
    endpoint = await choose_endpoint(model)
-    await increment_usage(endpoint, model)
     is_openai_endpoint = "/v1" in endpoint
     if is_openai_endpoint:
+        # strip Ollama's ":latest" tag before calling an OpenAI-style endpoint
+        if ":latest" in model:
+            model = model.split(":", 1)[0]
         params = {
             "messages": messages,
             "model": model,
@@ -581,7 +595,7 @@
         oclient = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
     else:
         client = ollama.AsyncClient(host=endpoint)
-
+    await increment_usage(endpoint, model)
     # 3. Async generator that streams chat data and decrements the counter
     async def stream_chat_response():
         try:
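
Note on the patch: the ":latest" normalization now appears at three separate sites. A small helper could consolidate it; the sketch below is only a suggestion, and the name strip_latest_tag is assumed here rather than taken from router.py:

    def strip_latest_tag(model: str) -> str:
        # Ollama appends ":latest" to untagged model names; OpenAI-style
        # endpoints expect the bare name, e.g. "llama3:latest" -> "llama3".
        return model.split(":", 1)[0] if ":latest" in model else model

The stripping step at each of the three sites would then reduce to model = strip_latest_tag(model).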