diff --git a/router.py b/router.py
index 98f1fcb..7be341e 100644
--- a/router.py
+++ b/router.py
@@ -10,7 +10,8 @@ import json, random, asyncio, yaml, httpx, ollama, openai
 from pathlib import Path
 from typing import Dict, Set, List
 from fastapi import FastAPI, Request, HTTPException
-from starlette.responses import StreamingResponse, JSONResponse, Response
+from fastapi.staticfiles import StaticFiles
+from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLResponse
 from pydantic import Field
 from pydantic_settings import BaseSettings
 from collections import defaultdict
@@ -71,12 +72,19 @@ async def fetch_available_models(endpoint: str) -> Set[str]:
     """
     try:
         async with httpx.AsyncClient(timeout=1.0) as client:
-            resp = await client.get(f"{endpoint}/api/tags")
+            if "/v1" in endpoint:
+                resp = await client.get(f"{endpoint}/models")
+            else:
+                resp = await client.get(f"{endpoint}/api/tags")
             resp.raise_for_status()
             data = resp.json()
             # Expected format:
             # {"models": [{"name": "model1"}, {"name": "model2"}]}
-            return {m.get("name") for m in data.get("models", []) if m.get("name")}
+            if "/v1" in endpoint:
+                models = {m.get("id") for m in data.get("data", []) if m.get("id")}
+            else:
+                models = {m.get("name") for m in data.get("models", []) if m.get("name")}
+            return models
     except Exception:
         # Treat any error as if the endpoint offers no models
         return set()
@@ -116,6 +124,13 @@ async def fetch_endpoint_details(endpoint: str, route: str, detail: str) -> List
         # If anything goes wrong we cannot reply versions
         return {detail: "N/A"}
 
+def ep2base(ep):
+    if "/v1" in ep:
+        base_url = ep
+    else:
+        base_url = ep+"/v1"
+    return base_url
+
 def dedupe_on_keys(dicts, key_fields):
     """
     Helper function to deduplicate endpoint details based on given dict keys.
@@ -412,7 +427,7 @@ async def embedding_proxy(request: Request):
     )
 
 # -------------------------------------------------------------
-# 8. API route – Embed
+# 9. API route – Embed
 # -------------------------------------------------------------
 @app.post("/api/embed")
 async def embed_proxy(request: Request):
@@ -468,7 +483,7 @@ async def embed_proxy(request: Request):
     )
 
 # -------------------------------------------------------------
-# 9. API route – Create
+# 10. API route – Create
 # -------------------------------------------------------------
 @app.post("/api/create")
 async def create_proxy(request: Request):
@@ -516,7 +531,7 @@ async def create_proxy(request: Request):
     return dict(final_status)
 
 # -------------------------------------------------------------
-# 10. API route – Show
+# 11. API route – Show
 # -------------------------------------------------------------
 @app.post("/api/show")
 async def show_proxy(request: Request):
@@ -549,7 +564,7 @@ async def show_proxy(request: Request):
     return show
 
 # -------------------------------------------------------------
-# 11. API route – Copy
+# 12. API route – Copy
 # -------------------------------------------------------------
 @app.post("/api/copy")
 async def copy_proxy(request: Request):
@@ -595,7 +610,7 @@ async def copy_proxy(request: Request):
     )
 
 # -------------------------------------------------------------
-# 12. API route – Delete
+# 13. API route – Delete
 # -------------------------------------------------------------
 @app.delete("/api/delete")
 async def delete_proxy(request: Request):
@@ -636,7 +651,7 @@ async def delete_proxy(request: Request):
     )
 
 # -------------------------------------------------------------
-# 13. API route – Pull
+# 14. API route – Pull
 # -------------------------------------------------------------
 @app.post("/api/pull")
 async def pull_proxy(request: Request):
@@ -676,7 +691,7 @@
     return dict(final_status)
 
 # -------------------------------------------------------------
-# 14. API route – Push
+# 15. API route – Push
 # -------------------------------------------------------------
 @app.post("/api/push")
 async def push_proxy(request: Request):
@@ -717,7 +732,7 @@
 
 
 # -------------------------------------------------------------
-# 15. API route – Version
+# 16. API route – Version
 # -------------------------------------------------------------
 @app.get("/api/version")
 async def version_proxy(request: Request):
@@ -739,7 +754,7 @@
     )
 
 # -------------------------------------------------------------
-# 16. API route – tags
+# 17. API route – tags
 # -------------------------------------------------------------
 @app.get("/api/tags")
 async def tags_proxy(request: Request):
@@ -748,7 +763,8 @@
     """
 
     # 1. Query all endpoints for models
-    tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints]
+    tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
+    tasks += [fetch_endpoint_details(ep, "/models", "data") for ep in config.endpoints if "/v1" in ep]
     all_models = await asyncio.gather(*tasks)
 
     models = {'models': []}
@@ -762,7 +778,7 @@
     )
 
 # -------------------------------------------------------------
-# 17. API route – ps
+# 18. API route – ps
 # -------------------------------------------------------------
 @app.get("/api/ps")
 async def ps_proxy(request: Request):
@@ -771,7 +787,7 @@
     """
 
     # 1. Query all endpoints for running models
-    tasks = [fetch_endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints]
+    tasks = [fetch_endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints if "/v1" not in ep]
    loaded_models = await asyncio.gather(*tasks)
 
     models = {'models': []}
@@ -785,7 +801,30 @@
     )
 
 # -------------------------------------------------------------
-# 18. API route – OpenAI compatible Embedding
+# 19. Proxy usage route – for monitoring
+# -------------------------------------------------------------
+@app.get("/api/usage")
+async def usage_proxy(request: Request):
+    """
+    Return a snapshot of the usage counter for each endpoint.
+    Useful for debugging / monitoring.
+    """
+    return {"usage_counts": usage_counts}
+
+# -------------------------------------------------------------
+# 20. Proxy config route – for monitoring and frontend usage
+# -------------------------------------------------------------
+@app.get("/api/config")
+async def config_proxy(request: Request):
+    """
+    Return a simple JSON object that contains the configured
+    Ollama endpoints. The front‑end uses this to display
+    which endpoints are being proxied.
+    """
+    return {"endpoints": config.endpoints}
+
+# -------------------------------------------------------------
+# 21. API route – OpenAI compatible Embedding
 # -------------------------------------------------------------
 @app.post("/v1/embeddings")
 async def openai_embedding_proxy(request: Request):
@@ -826,7 +865,7 @@
     return async_gen
 
 # -------------------------------------------------------------
-# 19. API route – OpenAI compatible Chat Completions
+# 22. API route – OpenAI compatible Chat Completions
 # -------------------------------------------------------------
 @app.post("/v1/chat/completions")
 async def openai_chat_completions_proxy(request: Request):
@@ -851,8 +890,31 @@
     temperature = payload.get("temperature")
     top_p = payload.get("top_p")
     max_tokens = payload.get("max_tokens")
-    tools =payload.get("tools")
+    tools = payload.get("tools")
+
+    headers = request.headers
+    api_key = headers.get("Authorization")
+    api_key = api_key.split()[1] if api_key else "ollama"
+    params = {
+        "messages": messages,
+        "model": model,
+        "frequency_penalty": frequency_penalty,
+        "presence_penalty": presence_penalty,
+        "seed": seed,
+        "stop": stop,
+        "stream": stream,
+        "temperature": temperature,
+        "top_p": top_p,
+        "max_tokens": max_tokens
+    }
+
+    if tools is not None:
+        params["tools"] = tools
+    if response_format is not None:
+        params["response_format"] = response_format
+    if stream_options is not None:
+        params["stream_options"] = stream_options
 
 
     if not model:
         raise HTTPException(
@@ -868,13 +930,14 @@
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
-    oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key="ollama")
-
+    base_url = ep2base(endpoint)
+    oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
+
     # 3. Async generator that streams completions data and decrements the counter
     async def stream_ochat_response():
         try:
             # The chat method returns a generator of dicts (or GenerateResponse)
-            async_gen = await oclient.chat.completions.create(messages=messages, model=model, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, response_format=response_format, seed=seed, stop=stop, stream=stream, stream_options=stream_options, temperature=temperature, top_p=top_p, max_tokens=max_tokens, tools=tools)
+            async_gen = await oclient.chat.completions.create(**params)
             if stream == True:
                 async for chunk in async_gen:
                     data = (
@@ -904,7 +967,7 @@
     )
 
 # -------------------------------------------------------------
-# 20. API route – OpenAI compatible Completions
+# 23. API route – OpenAI compatible Completions
 # -------------------------------------------------------------
 @app.post("/v1/completions")
 async def openai_completions_proxy(request: Request):
@@ -928,8 +991,28 @@
     temperature = payload.get("temperature")
     top_p = payload.get("top_p")
     max_tokens = payload.get("max_tokens")
-    suffix =payload.get("suffix")
+    suffix = payload.get("suffix")
+
+    headers = request.headers
+    api_key = headers.get("Authorization")
+    api_key = api_key.split()[1] if api_key else "ollama"
+    params = {
+        "prompt": prompt,
+        "model": model,
+        "frequency_penalty": frequency_penalty,
+        "presence_penalty": presence_penalty,
+        "seed": seed,
+        "stop": stop,
+        "stream": stream,
+        "temperature": temperature,
+        "top_p": top_p,
+        "max_tokens": max_tokens,
+        "suffix": suffix
+    }
+
+    if stream_options is not None:
+        params["stream_options"] = stream_options
 
 
     if not model:
         raise HTTPException(
@@ -945,13 +1028,14 @@
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     await increment_usage(endpoint, model)
-    oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key="ollama")
+    base_url = ep2base(endpoint)
+    oclient = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
 
     # 3. Async generator that streams completions data and decrements the counter
     async def stream_ocompletions_response():
         try:
             # The chat method returns a generator of dicts (or GenerateResponse)
-            async_gen = await oclient.completions.create(model=model, prompt=prompt, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, seed=seed, stop=stop, stream=stream, stream_options=stream_options, temperature=temperature, top_p=top_p, max_tokens=max_tokens, suffix=suffix)
+            async_gen = await oclient.completions.create(**params)
             if stream == True:
                 async for chunk in async_gen:
                     data = (
@@ -981,17 +1065,44 @@
     )
 
 # -------------------------------------------------------------
-# 21. OpenAI API compatible endpoints #ToDo
+# 24. OpenAI API compatible models endpoint
 # -------------------------------------------------------------
-@app.post("/v1/models")
-async def not_implemented_yet(request: Request):
+@app.get("/v1/models")
+async def openai_models_proxy(request: Request):
+    """
+    Proxy a models request to Ollama endpoints and reply with a unique list of all models.
 
-    return Response(
-        status_code=501
+    """
+    # 1. Query all endpoints for models
+    tasks = [fetch_endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
+    tasks += [fetch_endpoint_details(ep, "/models", "data") for ep in config.endpoints if "/v1" in ep]
+    all_models = await asyncio.gather(*tasks)
+
+    models = {'data': []}
+    for modellist in all_models:
+        models['data'] += modellist
+
+    # 2. Return a JSONResponse with a deduplicated list of unique models for inference
+    return JSONResponse(
+        content={"data": dedupe_on_keys(models['data'], ['name'])},
+        status_code=200,
     )
 
 # -------------------------------------------------------------
-# 22. FastAPI startup event – load configuration
+# 25. Serve the static front‑end
 # -------------------------------------------------------------
+app.mount("/static", StaticFiles(directory="static"), name="static")
+
+@app.get("/", response_class=HTMLResponse)
+async def index(request: Request):
+    """
+    Render the landing page that lists the configured endpoints
+    and the models available / running.
+    """
+    return HTMLResponse(content=Path("static/index.html").read_text(), status_code=200)
+
+# -------------------------------------------------------------
+# 26. FastAPI startup event – load configuration
+# -------------------------------------------------------------
 @app.on_event("startup")
 async def startup_event() -> None: