Adding OpenAI compatibility

New Endpoints
New Requirements
This commit is contained in:
Alpha Nerd 2025-08-28 09:40:33 +02:00 committed by GitHub
parent cdb4485334
commit 516ec8b102
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 203 additions and 6 deletions

View file

@ -2,13 +2,16 @@ annotated-types==0.7.0
anyio==4.10.0
certifi==2025.8.3
click==8.2.1
distro==1.9.0
exceptiongroup==1.3.0
fastapi==0.116.1
h11==0.16.0
httpcore==1.0.9
httpx==0.28.1
idna==3.10
jiter==0.10.0
ollama==0.5.3
openai==1.102.0
pydantic==2.11.7
pydantic-settings==2.10.1
pydantic_core==2.33.2
@ -16,6 +19,7 @@ python-dotenv==1.1.1
PyYAML==6.0.2
sniffio==1.3.1
starlette==0.47.2
tqdm==4.67.1
typing-inspection==0.4.1
typing_extensions==4.14.1
uvicorn==0.35.0

205
router.py
View file

@ -6,7 +6,7 @@ version: 0.1
license: AGPL
"""
# -------------------------------------------------------------
import json, random, asyncio, yaml, httpx, ollama
import json, random, asyncio, yaml, httpx, ollama, openai
from pathlib import Path
from typing import Dict, Set, List
from fastapi import FastAPI, Request, HTTPException
@ -314,7 +314,7 @@ async def chat_proxy(request: Request):
)
# -------------------------------------------------------------
# 8. API route Embedding
# 8. API route Embedding - deprecated
# -------------------------------------------------------------
@app.post("/api/embeddings")
async def embedding_proxy(request: Request):
@ -742,12 +742,205 @@ async def ps_proxy(request: Request):
)
# -------------------------------------------------------------
# 18. OpenAI API compatible endpoints #ToDo
# 18. API route OpenAI compatible Embedding
# -------------------------------------------------------------
@app.post("/v1/embeddings")
async def openai_embedding_proxy(request: Request):
    """
    Proxy an OpenAI API compatible embedding request to Ollama and reply with embeddings.

    The JSON body must be an object with non-empty 'model' and 'input' fields.
    A list 'input' is forwarded unchanged; any other value is wrapped in a
    one-element list before being sent upstream.

    Raises:
        HTTPException: status 400 when the body is not valid JSON, is not a
            JSON object, or lacks 'model' or 'input'.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        # Without this check a JSON array/scalar body would fail on .get()
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    model = payload.get("model")
    embedding_input = payload.get("input")
    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not embedding_input:
        raise HTTPException(
            status_code=400, detail="Missing required field 'input'"
        )

    # Only wrap non-list values: wrapping an already-batched list would nest
    # it one level too deep.
    if isinstance(embedding_input, list):
        inputs = embedding_input
    else:
        inputs = [embedding_input]

    # 2. Endpoint logic
    endpoint = await choose_endpoint(model)
    await increment_usage(endpoint, model)
    try:
        oclient = openai.AsyncOpenAI(base_url=endpoint + "/v1", api_key="ollama")
        # 3. Request the embeddings from the chosen endpoint (not streamed)
        response = await oclient.embeddings.create(input=inputs, model=model)
    finally:
        # Ensure counter is decremented even if the upstream call fails
        await decrement_usage(endpoint, model)

    # 4. Return the client's response object as the route result
    return response
# -------------------------------------------------------------
# 19. API route OpenAI compatible Chat Completions
# -------------------------------------------------------------
@app.post("/v1/chat/completions")
async def openai_chat_completions_proxy(request: Request):
    """
    Proxy an OpenAI API compatible chat completions request to Ollama.

    The JSON body must be an object with a non-empty 'model' and a list
    'messages'. Optional sampling/tool fields are forwarded only when the
    caller supplied a non-null value. When 'stream' is truthy the reply is
    sent as server-sent events terminated by a '[DONE]' event; otherwise a
    single JSON document is sent.

    Raises:
        HTTPException: status 400 when the body is not valid JSON, is not a
            JSON object, lacks 'model', or 'messages' is not a list.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        # Without this check a JSON array/scalar body would fail on .get()
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    model = payload.get("model")
    messages = payload.get("messages")
    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not isinstance(messages, list):
        raise HTTPException(
            status_code=400, detail="Missing required field 'messages' (must be a list)"
        )

    # Normalise once so the upstream call and the response framing below
    # always agree on whether this is a streaming request.
    streaming = bool(payload.get("stream"))

    # Forward only the optional fields the caller actually set; absent or
    # null fields are omitted instead of being sent upstream as explicit
    # None values.
    optional_fields = (
        "frequency_penalty",
        "presence_penalty",
        "response_format",
        "seed",
        "stop",
        "stream_options",
        "temperature",
        "top_p",
        "max_tokens",
        "tools",
    )
    optional_params = {
        name: payload[name]
        for name in optional_fields
        if payload.get(name) is not None
    }

    # 2. Endpoint logic
    endpoint = await choose_endpoint(model)
    await increment_usage(endpoint, model)
    oclient = openai.AsyncOpenAI(base_url=endpoint + "/v1", api_key="ollama")

    # 3. Async generator that streams completions data and decrements the counter
    async def stream_ochat_response():
        try:
            result = await oclient.chat.completions.create(
                messages=messages,
                model=model,
                stream=streaming,
                **optional_params,
            )
            if streaming:
                # Iterate the upstream chunks and re-emit each as an SSE event
                async for chunk in result:
                    data = (
                        chunk.model_dump_json()
                        if hasattr(chunk, "model_dump_json")
                        else json.dumps(chunk)
                    )
                    yield f"data: {data}\n\n".encode("utf-8")
                # Final DONE event
                yield b"data: [DONE]\n\n"
            else:
                json_line = (
                    result.model_dump_json()
                    if hasattr(result, "model_dump_json")
                    else json.dumps(result)
                )
                yield json_line.encode("utf-8") + b"\n"
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, model)

    # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_ochat_response(),
        # SSE framing must be announced as an event stream, not as JSON
        media_type="text/event-stream" if streaming else "application/json",
    )
# -------------------------------------------------------------
# 20. API route OpenAI compatible Completions
# -------------------------------------------------------------
@app.post("/v1/completions")
async def openai_completions_proxy(request: Request):
    """
    Proxy an OpenAI API compatible (text) completions request to Ollama.

    The JSON body must be an object with non-empty 'model' and 'prompt'
    fields. Optional sampling fields are forwarded only when the caller
    supplied a non-null value. When 'stream' is truthy the reply is sent as
    server-sent events terminated by a '[DONE]' event; otherwise a single
    JSON document is sent.

    Raises:
        HTTPException: status 400 when the body is not valid JSON, is not a
            JSON object, or lacks 'model' or 'prompt'.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = json.loads(body_bytes.decode("utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        # Without this check a JSON array/scalar body would fail on .get()
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    model = payload.get("model")
    prompt = payload.get("prompt")
    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not prompt:
        raise HTTPException(
            status_code=400, detail="Missing required field 'prompt'"
        )

    # Normalise once so the upstream call and the response framing below
    # always agree on whether this is a streaming request.
    streaming = bool(payload.get("stream"))

    # Forward only the optional fields the caller actually set; absent or
    # null fields are omitted instead of being sent upstream as explicit
    # None values.
    optional_fields = (
        "frequency_penalty",
        "presence_penalty",
        "seed",
        "stop",
        "stream_options",
        "temperature",
        "top_p",
        "max_tokens",
        "suffix",
    )
    optional_params = {
        name: payload[name]
        for name in optional_fields
        if payload.get(name) is not None
    }

    # 2. Endpoint logic
    endpoint = await choose_endpoint(model)
    await increment_usage(endpoint, model)
    oclient = openai.AsyncOpenAI(base_url=endpoint + "/v1", api_key="ollama")

    # 3. Async generator that streams completions data and decrements the counter
    async def stream_ocompletions_response():
        try:
            result = await oclient.completions.create(
                model=model,
                prompt=prompt,
                stream=streaming,
                **optional_params,
            )
            if streaming:
                # Iterate the upstream chunks and re-emit each as an SSE event
                async for chunk in result:
                    data = (
                        chunk.model_dump_json()
                        if hasattr(chunk, "model_dump_json")
                        else json.dumps(chunk)
                    )
                    yield f"data: {data}\n\n".encode("utf-8")
                # Final DONE event
                yield b"data: [DONE]\n\n"
            else:
                json_line = (
                    result.model_dump_json()
                    if hasattr(result, "model_dump_json")
                    else json.dumps(result)
                )
                yield json_line.encode("utf-8") + b"\n"
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, model)

    # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_ocompletions_response(),
        # SSE framing must be announced as an event stream, not as JSON
        media_type="text/event-stream" if streaming else "application/json",
    )
# -------------------------------------------------------------
# 21. OpenAI API compatible endpoints #ToDo
# -------------------------------------------------------------
@app.post("/v1/models")
@app.post("/v1/embeddings")
async def not_implemented_yet(request: Request):
return Response(
@ -755,7 +948,7 @@ async def not_implemented_yet(request: Request):
)
# -------------------------------------------------------------
# 19. FastAPI startup event load configuration
# 22. FastAPI startup event load configuration
# -------------------------------------------------------------
@app.on_event("startup")
async def startup_event() -> None:
@ -763,4 +956,4 @@ async def startup_event() -> None:
# Load YAML config (or use defaults if not present)
config = Config.from_yaml(Path("config.yaml"))
print(f"Loaded configuration:\n endpoints={config.endpoints},\n "
f"max_concurrent_connections={config.max_concurrent_connections}")
f"max_concurrent_connections={config.max_concurrent_connections}")