diff --git a/requirements.txt b/requirements.txt index 988d1df..e39b50c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,13 +2,16 @@ annotated-types==0.7.0 anyio==4.10.0 certifi==2025.8.3 click==8.2.1 +distro==1.9.0 exceptiongroup==1.3.0 fastapi==0.116.1 h11==0.16.0 httpcore==1.0.9 httpx==0.28.1 idna==3.10 +jiter==0.10.0 ollama==0.5.3 +openai==1.102.0 pydantic==2.11.7 pydantic-settings==2.10.1 pydantic_core==2.33.2 @@ -16,6 +19,7 @@ python-dotenv==1.1.1 PyYAML==6.0.2 sniffio==1.3.1 starlette==0.47.2 +tqdm==4.67.1 typing-inspection==0.4.1 typing_extensions==4.14.1 uvicorn==0.35.0 diff --git a/router.py b/router.py index c12cd33..169f190 100644 --- a/router.py +++ b/router.py @@ -6,7 +6,7 @@ version: 0.1 license: AGPL """ # ------------------------------------------------------------- -import json, random, asyncio, yaml, httpx, ollama +import json, random, asyncio, yaml, httpx, ollama, openai from pathlib import Path from typing import Dict, Set, List from fastapi import FastAPI, Request, HTTPException @@ -314,7 +314,7 @@ async def chat_proxy(request: Request): ) # ------------------------------------------------------------- -# 8. API route – Embedding +# 8. API route – Embedding - deprecated # ------------------------------------------------------------- @app.post("/api/embeddings") async def embedding_proxy(request: Request): @@ -742,12 +742,205 @@ async def ps_proxy(request: Request): ) # ------------------------------------------------------------- -# 18. OpenAI API compatible endpoints #ToDo +# 18. API route – OpenAI compatible Embedding +# ------------------------------------------------------------- +@app.post("/v1/embeddings") +async def openai_embedding_proxy(request: Request): + """ + Proxy an OpenAI API compatible embedding request to Ollama and reply with embeddings. + + """ + # 1. Parse and validate request + try: + body_bytes = await request.body() + payload = json.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + input = payload.get("input") + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + if not input: + raise HTTPException( + status_code=400, detail="Missing required field 'input'" + ) + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # 2. Endpoint logic + endpoint = await choose_endpoint(model) + await increment_usage(endpoint, model) + oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key="ollama") + + # 3. Async generator that streams embedding data and decrements the counter + async_gen = await oclient.embeddings.create(input = [input], model=model) + + await decrement_usage(endpoint, model) + + # 5. Return a StreamingResponse backed by the generator + return async_gen + +# ------------------------------------------------------------- +# 19. API route – OpenAI compatible Chat Completions # ------------------------------------------------------------- @app.post("/v1/chat/completions") +async def openai_chat_completions_proxy(request: Request): + """ + Proxy an OpenAI API compatible chat completions request to Ollama and reply with a streaming response. + + """ + # 1. Parse and validate request + try: + body_bytes = await request.body() + payload = json.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + messages = payload.get("messages") + frequency_penalty = payload.get("frequency_penalty") + presence_penalty = payload.get("presence_penalty") + response_format = payload.get("response_format") + seed = payload.get("seed") + stop = payload.get("stop") + stream = payload.get("stream") + stream_options = payload.get("stream_options") + temperature = payload.get("temperature") + top_p = payload.get("top_p") + max_tokens = payload.get("max_tokens") + tools =payload.get("tools") + + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + if not isinstance(messages, list): + raise HTTPException( + status_code=400, detail="Missing required field 'messages' (must be a list)" + ) + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # 2. Endpoint logic + endpoint = await choose_endpoint(model) + await increment_usage(endpoint, model) + oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key="ollama") + + # 3. Async generator that streams completions data and decrements the counter + async def stream_ochat_response(): + try: + # The chat method returns a generator of dicts (or GenerateResponse) + async_gen = await oclient.chat.completions.create(messages=messages, model=model, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, response_format=response_format, seed=seed, stop=stop, stream=stream, stream_options=stream_options, temperature=temperature, top_p=top_p, max_tokens=max_tokens, tools=tools) + if stream == True: + async for chunk in async_gen: + data = ( + chunk.model_dump_json() + if hasattr(chunk, "model_dump_json") + else json.dumps(chunk) + ) + yield f"data: {data}\n\n".encode("utf-8") + # Final DONE event + yield b"data: [DONE]\n\n" + else: + json_line = ( + async_gen.model_dump_json() + if hasattr(async_gen, "model_dump_json") + else json.dumps(async_gen) + ) + yield json_line.encode("utf-8") + b"\n" + + finally: + # Ensure counter is decremented even if an exception occurs + await decrement_usage(endpoint, model) + + # 4. Return a StreamingResponse backed by the generator + return StreamingResponse( + stream_ochat_response(), + media_type="application/json", + ) + +# ------------------------------------------------------------- +# 20. API route – OpenAI compatible Completions +# ------------------------------------------------------------- @app.post("/v1/completions") +async def openai_completions_proxy(request: Request): + """ + Proxy an OpenAI API compatible chat completions request to Ollama and reply with a streaming response. + + """ + # 1. Parse and validate request + try: + body_bytes = await request.body() + payload = json.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + prompt = payload.get("prompt") + frequency_penalty = payload.get("frequency_penalty") + presence_penalty = payload.get("presence_penalty") + seed = payload.get("seed") + stop = payload.get("stop") + stream = payload.get("stream") + stream_options = payload.get("stream_options") + temperature = payload.get("temperature") + top_p = payload.get("top_p") + max_tokens = payload.get("max_tokens") + suffix =payload.get("suffix") + + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + if not prompt: + raise HTTPException( + status_code=400, detail="Missing required field 'prompt'" + ) + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # 2. Endpoint logic + endpoint = await choose_endpoint(model) + await increment_usage(endpoint, model) + oclient = openai.AsyncOpenAI(base_url=endpoint+"/v1", api_key="ollama") + + # 3. Async generator that streams completions data and decrements the counter + async def stream_ocompletions_response(): + try: + # The chat method returns a generator of dicts (or GenerateResponse) + async_gen = await oclient.completions.create(model=model, prompt=prompt, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, seed=seed, stop=stop, stream=stream, stream_options=stream_options, temperature=temperature, top_p=top_p, max_tokens=max_tokens, suffix=suffix) + if stream == True: + async for chunk in async_gen: + data = ( + chunk.model_dump_json() + if hasattr(chunk, "model_dump_json") + else json.dumps(chunk) + ) + yield f"data: {data}\n\n".encode("utf-8") + # Final DONE event + yield b"data: [DONE]\n\n" + else: + json_line = ( + async_gen.model_dump_json() + if hasattr(async_gen, "model_dump_json") + else json.dumps(async_gen) + ) + yield json_line.encode("utf-8") + b"\n" + + finally: + # Ensure counter is decremented even if an exception occurs + await decrement_usage(endpoint, model) + + # 4. Return a StreamingResponse backed by the generator + return StreamingResponse( + stream_ocompletions_response(), + media_type="application/json", + ) + +# ------------------------------------------------------------- +# 21. OpenAI API compatible endpoints #ToDo +# ------------------------------------------------------------- @app.post("/v1/models") -@app.post("/v1/embeddings") async def not_implemented_yet(request: Request): return Response( @@ -755,7 +948,7 @@ async def not_implemented_yet(request: Request): ) # ------------------------------------------------------------- -# 19. FastAPI startup event – load configuration +# 22. FastAPI startup event – load configuration # ------------------------------------------------------------- @app.on_event("startup") async def startup_event() -> None: @@ -763,4 +956,4 @@ async def startup_event() -> None: # Load YAML config (or use defaults if not present) config = Config.from_yaml(Path("config.yaml")) print(f"Loaded configuration:\n endpoints={config.endpoints},\n " - f"max_concurrent_connections={config.max_concurrent_connections}") + f"max_concurrent_connections={config.max_concurrent_connections}") \ No newline at end of file