diff --git a/router.py b/router.py
index e49152e..eb2a90a 100644
--- a/router.py
+++ b/router.py
@@ -119,111 +119,112 @@ async def _ensure_success(resp: aiohttp.ClientResponse) -> None:
         text = await resp.text()
         raise HTTPException(status_code=resp.status, detail=text)

-async def fetch_available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
-    """
-    Query /api/tags and return a set of all model names that the
-    endpoint *advertises* (i.e. is capable of serving). This endpoint lists
-    every model that is installed on the Ollama instance, regardless of
-    whether the model is currently loaded into memory.
+class fetch:
+    async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
+        """
+        Query /api/tags and return a set of all model names that the
+        endpoint *advertises* (i.e. is capable of serving). This endpoint lists
+        every model that is installed on the Ollama instance, regardless of
+        whether the model is currently loaded into memory.

-    If the request fails (e.g. timeout, 5xx, or malformed response), an empty
-    set is returned.
-    """
-    headers = None
-    if api_key is not None:
-        headers = {"Authorization": "Bearer " + api_key}
+        If the request fails (e.g. timeout, 5xx, or malformed response), an empty
+        set is returned.
+        """
+        headers = None
+        if api_key is not None:
+            headers = {"Authorization": "Bearer " + api_key}

-    if endpoint in _models_cache:
-        models, cached_at = _models_cache[endpoint]
-        if _is_fresh(cached_at, 300):
-            return models
-        else:
-            # stale entry – drop it
-            del _models_cache[endpoint]
-
-    if endpoint in _error_cache:
-        if _is_fresh(_error_cache[endpoint], 1):
-            # Still within the short error TTL – pretend nothing is available
-            return set()
-        else:
-            # Error expired – remove it
-            del _error_cache[endpoint]
-
-    if "/v1" in endpoint:
-        endpoint_url = f"{endpoint}/models"
-        key = "data"
-    else:
-        endpoint_url = f"{endpoint}/api/tags"
-        key = "models"
-    client: aiohttp.ClientSession = app_state["session"]
-    try:
-        async with client.get(endpoint_url, headers=headers) as resp:
-            await _ensure_success(resp)
-            data = await resp.json()
-
-            items = data.get(key, [])
-            models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
-
-            if models:
-                _models_cache[endpoint] = (models, time.time())
+        if endpoint in _models_cache:
+            models, cached_at = _models_cache[endpoint]
+            if _is_fresh(cached_at, 300):
                 return models
             else:
-                # Empty list – treat as “no models”, but still cache for 300s
-                _models_cache[endpoint] = (models, time.time())
-                return models
-    except Exception as e:
-        # Treat any error as if the endpoint offers no models
-        print(f"[fetch_available_models] {endpoint} error: {e}")
-        _error_cache[endpoint] = time.time()
-        return set()
+                # stale entry – drop it
+                del _models_cache[endpoint]
+
+        if endpoint in _error_cache:
+            if _is_fresh(_error_cache[endpoint], 1):
+                # Still within the short error TTL – pretend nothing is available
+                return set()
+            else:
+                # Error expired – remove it
+                del _error_cache[endpoint]
+
+        if "/v1" in endpoint:
+            endpoint_url = f"{endpoint}/models"
+            key = "data"
+        else:
+            endpoint_url = f"{endpoint}/api/tags"
+            key = "models"
+        client: aiohttp.ClientSession = app_state["session"]
+        try:
+            async with client.get(endpoint_url, headers=headers) as resp:
+                await _ensure_success(resp)
+                data = await resp.json()
+
+                items = data.get(key, [])
+                models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
+
+                if models:
+                    _models_cache[endpoint] = (models, time.time())
+                    return models
+                else:
+                    # Empty list – treat as “no models”, but still cache for 300s
+                    _models_cache[endpoint] = (models, time.time())
+                    return models
+        except Exception as e:
+            # Treat any error as if the endpoint offers no models
+            print(f"[fetch.available_models] {endpoint} error: {e}")
+            _error_cache[endpoint] = time.time()
+            return set()

-async def fetch_loaded_models(endpoint: str) -> Set[str]:
-    """
-    Query /api/ps and return a set of model names that are currently
-    loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
-    set is returned.
-    """
-    client: aiohttp.ClientSession = app_state["session"]
-    try:
-        async with client.get(f"{endpoint}/api/ps") as resp:
-            await _ensure_success(resp)
-            data = await resp.json()
-            # The response format is:
-            # {"models": [{"name": "model1"}, {"name": "model2"}]}
-            models = {m.get("name") for m in data.get("models", []) if m.get("name")}
-            return models
-    except Exception:
-        # If anything goes wrong we simply assume the endpoint has no models
-        return set()
+    async def loaded_models(endpoint: str) -> Set[str]:
+        """
+        Query /api/ps and return a set of model names that are currently
+        loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
+        set is returned.
+        """
+        client: aiohttp.ClientSession = app_state["session"]
+        try:
+            async with client.get(f"{endpoint}/api/ps") as resp:
+                await _ensure_success(resp)
+                data = await resp.json()
+                # The response format is:
+                # {"models": [{"name": "model1"}, {"name": "model2"}]}
+                models = {m.get("name") for m in data.get("models", []) if m.get("name")}
+                return models
+        except Exception:
+            # If anything goes wrong we simply assume the endpoint has no models
+            return set()

-async def fetch_endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
-    """
-    Query /<route> to fetch and return a List of dicts with details
-    for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail.
-    """
-    client: aiohttp.ClientSession = app_state["session"]
-    headers = None
-    if api_key is not None:
-        headers = {"Authorization": "Bearer " + api_key}
-
-    try:
-        async with client.get(f"{endpoint}{route}", headers=headers) as resp:
-            await _ensure_success(resp)
-            data = await resp.json()
-            detail = data.get(detail, [])
-            return detail
-    except Exception as e:
-        # If anything goes wrong we cannot reply details
-        print(e)
-        return []
+    async def endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None) -> List[dict]:
+        """
+        Query /<route> to fetch and return a List of dicts with details
+        for the corresponding Ollama endpoint. If the request fails, an empty list is returned.
+        """
+        client: aiohttp.ClientSession = app_state["session"]
+        headers = None
+        if api_key is not None:
+            headers = {"Authorization": "Bearer " + api_key}
+
+        try:
+            async with client.get(f"{endpoint}{route}", headers=headers) as resp:
+                await _ensure_success(resp)
+                data = await resp.json()
+                detail = data.get(detail, [])
+                return detail
+        except Exception as e:
+            # If anything goes wrong we cannot report any details
+            print(e)
+            return []

-def ep2base(ep):
-    if "/v1" in ep:
-        base_url = ep
-    else:
-        base_url = ep+"/v1"
-    return base_url
+    def ep2base(ep):
+        if "/v1" in ep:
+            base_url = ep
+        else:
+            base_url = ep+"/v1"
+        return base_url

 def dedupe_on_keys(dicts, key_fields):
     """
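Reviewer note on the hunk above: `fetch.available_models` leans on a two-tier cache defined elsewhere in router.py, a 300 s positive TTL plus a 1 s negative (error) TTL. Below is a minimal, self-contained sketch of that discipline; `_models_cache`, `_error_cache`, and `_is_fresh` are re-declared here with shapes inferred from how the patched code uses them, so treat it as an illustration rather than the module's actual definitions.

```python
# Sketch only: in router.py these globals and _is_fresh live elsewhere;
# their shapes are inferred from the diff.
import time
from typing import Dict, Optional, Set, Tuple

_models_cache: Dict[str, Tuple[Set[str], float]] = {}  # endpoint -> (models, cached_at)
_error_cache: Dict[str, float] = {}                    # endpoint -> time of last failure

def _is_fresh(stamp: float, ttl_seconds: float) -> bool:
    # A timestamp is fresh while it is younger than its TTL.
    return (time.time() - stamp) < ttl_seconds

def cached_models(endpoint: str) -> Optional[Set[str]]:
    """Return a fresh positive entry, an empty set while the short error
    TTL holds, or None when the caller must re-query the endpoint."""
    if endpoint in _models_cache:
        models, cached_at = _models_cache[endpoint]
        if _is_fresh(cached_at, 300):
            return models
        del _models_cache[endpoint]        # stale entry - drop it
    if endpoint in _error_cache:
        if _is_fresh(_error_cache[endpoint], 1):
            return set()                   # recent failure - advertise nothing
        del _error_cache[endpoint]         # error TTL expired - forget it
    return None
```

The short negative TTL keeps a failing endpoint from being hammered on every routing decision while still letting it rejoin the pool within a second.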
+ """ + client: aiohttp.ClientSession = app_state["session"] + headers = None + if api_key is not None: + headers = {"Authorization": "Bearer " + api_key} + + try: + async with client.get(f"{endpoint}{route}", headers=headers) as resp: + await _ensure_success(resp) + data = await resp.json() + detail = data.get(detail, []) + return detail + except Exception as e: + # If anything goes wrong we cannot reply details + print(e) + return [] -def ep2base(ep): - if "/v1" in ep: - base_url = ep - else: - base_url = ep+"/v1" - return base_url + def ep2base(ep): + if "/v1" in ep: + base_url = ep + else: + base_url = ep+"/v1" + return base_url def dedupe_on_keys(dicts, key_fields): """ @@ -288,6 +289,19 @@ class rechunk: else: rechunk["message"] = {"role": chunk.choices[0].message.role, "content": chunk.choices[0].message.content, "thinking": None, "images": None, "tool_name": None, "tool_calls": None} return rechunk + + def openai_completion2ollama(chunk: dict, stream: bool, start_ts: float): + rechunk = { "model": chunk.model, + "created_at": iso8601_ns(), + "load_duration": None, + "done_reason": chunk.choices[0].finish_reason, + "total_duration": None, + "eval_duration": (int((time.perf_counter() - start_ts) * 1000) if chunk.usage is not None else None), + "thinking": None, + "context": None, + "response": chunk.choices[0].text + } + return rechunk # ------------------------------------------------------------------ # SSE Helpser @@ -350,8 +364,8 @@ async def choose_endpoint(model: str) -> str: 6️⃣ If no endpoint advertises the model at all, raise an error. """ # 1️⃣ Gather advertised‑model sets for all endpoints concurrently - tag_tasks = [fetch_available_models(ep) for ep in config.endpoints if "/v1" not in ep] - tag_tasks += [fetch_available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep] + tag_tasks = [fetch.available_models(ep) for ep in config.endpoints if "/v1" not in ep] + tag_tasks += [fetch.available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep] advertised_sets = await asyncio.gather(*tag_tasks) # 2️⃣ Filter endpoints that advertise the requested model @@ -369,7 +383,7 @@ async def choose_endpoint(model: str) -> str: # 3️⃣ Among the candidates, find those that have the model *loaded* # (concurrently, but only for the filtered list) - load_tasks = [fetch_loaded_models(ep) for ep in candidate_endpoints] + load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints] loaded_sets = await asyncio.gather(*load_tasks) async with usage_lock: @@ -409,7 +423,6 @@ async def proxy(request: Request): """ Proxy a generate request to Ollama and stream the response back to the client. """ - # 1. Parse and validate request try: body_bytes = await request.body() payload = json.loads(body_bytes.decode("utf-8")) @@ -439,29 +452,50 @@ async def proxy(request: Request): except json.JSONDecodeError as e: raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e - # 2. Decide which endpoint to use + endpoint = await choose_endpoint(model) - - # Increment usage counter for this endpoint‑model pair await increment_usage(endpoint, model) + is_openai_endpoint = "/v1" in endpoint + if is_openai_endpoint: + params = { + "prompt": prompt, + "model": model, + } - # 3. 
@@ -350,8 +364,8 @@ async def choose_endpoint(model: str) -> str:
     6️⃣ If no endpoint advertises the model at all, raise an error.
     """
     # 1️⃣ Gather advertised‑model sets for all endpoints concurrently
-    tag_tasks = [fetch_available_models(ep) for ep in config.endpoints if "/v1" not in ep]
-    tag_tasks += [fetch_available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
+    tag_tasks = [fetch.available_models(ep) for ep in config.endpoints if "/v1" not in ep]
+    tag_tasks += [fetch.available_models(ep, config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
     advertised_sets = await asyncio.gather(*tag_tasks)

     # 2️⃣ Filter endpoints that advertise the requested model
@@ -369,7 +383,7 @@ async def choose_endpoint(model: str) -> str:

     # 3️⃣ Among the candidates, find those that have the model *loaded*
     #     (concurrently, but only for the filtered list)
-    load_tasks = [fetch_loaded_models(ep) for ep in candidate_endpoints]
+    load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
     loaded_sets = await asyncio.gather(*load_tasks)

     async with usage_lock:
@@ -409,7 +423,6 @@ async def proxy(request: Request):
     """
     Proxy a generate request to Ollama and stream the response back to the client.
     """
-    # 1. Parse and validate request
     try:
         body_bytes = await request.body()
         payload = json.loads(body_bytes.decode("utf-8"))
@@ -439,29 +452,50 @@ async def proxy(request: Request):
     except json.JSONDecodeError as e:
         raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

-    # 2. Decide which endpoint to use
+
     endpoint = await choose_endpoint(model)
-
-    # Increment usage counter for this endpoint‑model pair
     await increment_usage(endpoint, model)
+    is_openai_endpoint = "/v1" in endpoint
+    if is_openai_endpoint:
+        params = {
+            "prompt": prompt,
+            "model": model,
+        }

-    # 3. Create Ollama client instance
-    client = ollama.AsyncClient(host=endpoint)
+        optional_params = {
+            "stream": stream,
+        }
+
+        params.update({k: v for k, v in optional_params.items() if v is not None})
+        oclient = openai.AsyncOpenAI(base_url=endpoint, api_key=config.api_keys[endpoint])
+    else:
+        client = ollama.AsyncClient(host=endpoint)

     # 4. Async generator that streams data and decrements the counter
     async def stream_generate_response():
         try:
-            async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive)
+            if is_openai_endpoint:
+                start_ts = time.perf_counter()
+                async_gen = await oclient.completions.create(**params)
+            else:
+                async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive)
             if stream == True:
                 async for chunk in async_gen:
+                    if is_openai_endpoint:
+                        chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
                     if hasattr(chunk, "model_dump_json"):
                         json_line = chunk.model_dump_json()
                     else:
                         json_line = json.dumps(chunk)
                     yield json_line.encode("utf-8") + b"\n"
             else:
+                if is_openai_endpoint:
+                    response = rechunk.openai_completion2ollama(async_gen, stream, start_ts)
+                    response = json.dumps(response)
+                else:
+                    response = async_gen.model_dump_json()
                 json_line = (
-                    async_gen.model_dump_json()
+                    response
                     if hasattr(async_gen, "model_dump_json")
                     else json.dumps(async_gen)
                 )
@@ -513,7 +547,8 @@ async def chat_proxy(request: Request):
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
-    if "/v1" in endpoint:
+    is_openai_endpoint = "/v1" in endpoint
+    if is_openai_endpoint:
         params = {
             "messages": messages,
             "model": model,
@@ -530,7 +565,6 @@ async def chat_proxy(request: Request):
         client = ollama.AsyncClient(host=endpoint)

     # 3. Async generator that streams chat data and decrements the counter
-    is_openai_endpoint = "/v1" in endpoint
     async def stream_chat_response():
         try:
             # The chat method returns a generator of dicts (or GenerateResponse)
@@ -560,7 +594,6 @@ async def chat_proxy(request: Request):
                     if hasattr(async_gen, "model_dump_json")
                     else json.dumps(async_gen)
                 )
-                print(json_line)
                 yield json_line.encode("utf-8") + b"\n"

         finally:
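After the proxy hunks above, both streaming generators funnel every backend through the same serialization step: chunks from OpenAI-compatible endpoints are reshaped into the Ollama format first, then each chunk is emitted as one NDJSON line. A condensed sketch of that loop; the names are illustrative, not part of the router's API:

```python
# Condensed form of the loop inside stream_generate_response /
# stream_chat_response; "converter" stands in for rechunk.openai_completion2ollama.
import json

async def ndjson_stream(async_gen, converter=None):
    async for chunk in async_gen:
        if converter is not None:            # OpenAI backend: reshape first
            chunk = converter(chunk)
        if hasattr(chunk, "model_dump_json"):
            line = chunk.model_dump_json()   # pydantic model from the ollama client
        else:
            line = json.dumps(chunk)         # already a plain dict
        yield line.encode("utf-8") + b"\n"
```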
@@ -941,7 +974,7 @@ async def version_proxy(request: Request):
     """

     # 1. Query all endpoints for version
-    tasks = [fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
+    tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
     all_versions = await asyncio.gather(*tasks)

     def version_key(v):
@@ -964,8 +997,8 @@ async def tags_proxy(request: Request):
     """

     # 1. Query all endpoints for models
-    tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
-    tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
+    tasks = [fetch.endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
+    tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
     all_models = await asyncio.gather(*tasks)

     models = {'models': []}
@@ -988,7 +1021,7 @@ async def ps_proxy(request: Request):
     """

     # 1. Query all endpoints for running models
-    tasks = [fetch_endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints if "/v1" not in ep]
+    tasks = [fetch.endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints if "/v1" not in ep]
     loaded_models = await asyncio.gather(*tasks)

     models = {'models': []}
@@ -1300,8 +1333,8 @@ async def openai_models_proxy(request: Request):
     """

     # 1. Query all endpoints for models
-    tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
-    tasks += [fetch_endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
+    tasks = [fetch.endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
+    tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
     all_models = await asyncio.gather(*tasks)

     models = {'data': []}
@@ -1351,9 +1384,7 @@ async def health_proxy(request: Request):
     * The HTTP status code is 200 when everything is healthy, 503 otherwise.
     """
     # Run all health checks in parallel
-    tasks = [
-        fetch_endpoint_details(ep, "/api/version", "version") for ep in config.endpoints
-    ]
+    tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints]
     results = await asyncio.gather(*tasks, return_exceptions=True)
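A final note on the new `fetch` class: its functions take no `self` and are only ever invoked on the class itself (e.g. `fetch.available_models(ep)`), which works in Python 3 because functions accessed through the class are plain functions. Decorating them with `@staticmethod` would make the namespace intent explicit and keep the calls valid from instances as well; a sketch with bodies elided:

```python
# Hypothetical tightening of the namespace-class pattern from the patch;
# signatures copied from the diff, bodies elided.
from typing import List, Optional, Set

class fetch:
    @staticmethod
    async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
        ...

    @staticmethod
    async def loaded_models(endpoint: str) -> Set[str]:
        ...

    @staticmethod
    async def endpoint_details(endpoint: str, route: str, detail: str,
                               api_key: Optional[str] = None) -> List[dict]:
        ...

    @staticmethod
    def ep2base(ep: str) -> str:
        # Normalize an endpoint to its OpenAI-compatible /v1 base URL.
        return ep if "/v1" in ep else ep + "/v1"
```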